diff --git a/.flake8 b/.flake8 old mode 100644 new mode 100755 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs old mode 100644 new mode 100755 diff --git a/.github/scripts/.gitignore b/.github/scripts/.gitignore old mode 100644 new mode 100755 diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile old mode 100644 new mode 100755 diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/baker_zh.yml b/.github/workflows/baker_zh.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/ksponspeech.yml b/.github/workflows/ksponspeech.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/librispeech.yml b/.github/workflows/librispeech.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/rknn.yml b/.github/workflows/rknn.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-multi-corpora-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml old mode 100644 new mode 100755 diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 
new mode 100755 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100755 index 000000000..56d6da264 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "workbench.colorCustomizations": { + "activityBar.activeBackground": "#93e6fc", + "activityBar.background": "#93e6fc", + "activityBar.foreground": "#15202b", + "activityBar.inactiveForeground": "#15202b99", + "activityBarBadge.background": "#fa45d4", + "activityBarBadge.foreground": "#15202b", + "commandCenter.border": "#15202b99", + "sash.hoverBorder": "#93e6fc", + "statusBar.background": "#61dafb", + "statusBar.foreground": "#15202b", + "statusBarItem.hoverBackground": "#2fcefa", + "statusBarItem.remoteBackground": "#61dafb", + "statusBarItem.remoteForeground": "#15202b", + "titleBar.activeBackground": "#61dafb", + "titleBar.activeForeground": "#15202b", + "titleBar.inactiveBackground": "#61dafb99", + "titleBar.inactiveForeground": "#15202b99" + }, + "peacock.remoteColor": "#61dafb" +} \ No newline at end of file diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/Testing/Temporary/LastTest.log b/Testing/Temporary/LastTest.log new file mode 100644 index 000000000..4836f71d5 --- /dev/null +++ b/Testing/Temporary/LastTest.log @@ -0,0 +1,3 @@ +Start testing: Aug 21 17:10 KST +---------------------------------------------------------- +End testing: Aug 21 17:10 KST diff --git a/contributing.md b/contributing.md old mode 100644 new mode 100755 diff --git a/docker/README.md b/docker/README.md old mode 100644 new mode 100755 diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile old mode 100644 new mode 100755 diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.2-cuda11.8.dockerfile b/docker/torch2.2.2-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.2.2-cuda12.1.dockerfile b/docker/torch2.2.2-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.3.1-cuda11.8.dockerfile b/docker/torch2.3.1-cuda11.8.dockerfile old mode 100644 new mode 100755 
diff --git a/docker/torch2.3.1-cuda12.1.dockerfile b/docker/torch2.3.1-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.0-cuda11.8.dockerfile b/docker/torch2.4.0-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.0-cuda12.1.dockerfile b/docker/torch2.4.0-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.0-cuda12.4.dockerfile b/docker/torch2.4.0-cuda12.4.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.1-cuda11.8.dockerfile b/docker/torch2.4.1-cuda11.8.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.1-cuda12.1.dockerfile b/docker/torch2.4.1-cuda12.1.dockerfile old mode 100644 new mode 100755 diff --git a/docker/torch2.4.1-cuda12.4.dockerfile b/docker/torch2.4.1-cuda12.4.dockerfile old mode 100644 new mode 100755 diff --git a/docs/.gitignore b/docs/.gitignore old mode 100644 new mode 100755 diff --git a/docs/Makefile b/docs/Makefile old mode 100644 new mode 100755 diff --git a/docs/README.md b/docs/README.md old mode 100644 new mode 100755 diff --git a/docs/make.bat b/docs/make.bat old mode 100644 new mode 100755 diff --git a/docs/requirements.txt b/docs/requirements.txt old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav b/docs/source/_static/kaldi-align/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/at.wav b/docs/source/_static/kaldi-align/at.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/beside.wav b/docs/source/_static/kaldi-align/beside.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/curiosity.wav b/docs/source/_static/kaldi-align/curiosity.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/had.wav b/docs/source/_static/kaldi-align/had.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/i.wav b/docs/source/_static/kaldi-align/i.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/me.wav b/docs/source/_static/kaldi-align/me.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/moment.wav b/docs/source/_static/kaldi-align/moment.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/that.wav b/docs/source/_static/kaldi-align/that.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/kaldi-align/this.wav b/docs/source/_static/kaldi-align/this.wav old mode 100644 new mode 100755 diff --git a/docs/source/_static/logo.png b/docs/source/_static/logo.png old mode 100644 new mode 100755 diff --git a/docs/source/conf.py b/docs/source/conf.py old mode 100644 new mode 100755 diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst old mode 100644 new mode 100755 diff --git a/docs/source/contributing/doc.rst b/docs/source/contributing/doc.rst old mode 100644 new mode 100755 diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst old mode 100644 new mode 100755 diff --git a/docs/source/contributing/images/doc-contrib.png b/docs/source/contributing/images/doc-contrib.png old mode 100644 new mode 100755 diff --git a/docs/source/contributing/images/pre-commit-check-success.png b/docs/source/contributing/images/pre-commit-check-success.png old mode 100644 new mode 100755 diff --git 
a/docs/source/contributing/images/pre-commit-check.png b/docs/source/contributing/images/pre-commit-check.png old mode 100644 new mode 100755 diff --git a/docs/source/contributing/index.rst b/docs/source/contributing/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst old mode 100644 new mode 100755 diff --git a/docs/source/decoding-with-langugage-models/index.rst b/docs/source/decoding-with-langugage-models/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst old mode 100644 new mode 100755 diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst old mode 100644 new mode 100755 diff --git a/docs/source/docker/img/docker-hub.png b/docs/source/docker/img/docker-hub.png old mode 100644 new mode 100755 diff --git a/docs/source/docker/index.rst b/docs/source/docker/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst old mode 100644 new mode 100755 diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/data-preparation.rst b/docs/source/for-dummies/data-preparation.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/decoding.rst b/docs/source/for-dummies/decoding.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/index.rst b/docs/source/for-dummies/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst old mode 100644 new mode 100755 diff --git a/docs/source/for-dummies/training.rst b/docs/source/for-dummies/training.rst old mode 100644 new mode 100755 diff --git a/docs/source/fst-based-forced-alignment/diff.rst b/docs/source/fst-based-forced-alignment/diff.rst old mode 100644 new mode 100755 diff --git a/docs/source/fst-based-forced-alignment/index.rst b/docs/source/fst-based-forced-alignment/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/fst-based-forced-alignment/k2-based.rst b/docs/source/fst-based-forced-alignment/k2-based.rst old mode 100644 new mode 100755 diff --git a/docs/source/fst-based-forced-alignment/kaldi-based.rst b/docs/source/fst-based-forced-alignment/kaldi-based.rst old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/index.rst b/docs/source/huggingface/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/pic/hugging-face-sherpa-2.png b/docs/source/huggingface/pic/hugging-face-sherpa-2.png old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/pic/hugging-face-sherpa-3.png b/docs/source/huggingface/pic/hugging-face-sherpa-3.png old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/pic/hugging-face-sherpa.png b/docs/source/huggingface/pic/hugging-face-sherpa.png old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/pretrained-models.rst b/docs/source/huggingface/pretrained-models.rst old mode 100644 new mode 100755 diff --git a/docs/source/huggingface/spaces.rst b/docs/source/huggingface/spaces.rst old mode 100644 new mode 100755 diff --git a/docs/source/index.rst 
b/docs/source/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/README.md b/docs/source/installation/images/README.md old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg old mode 100644 new mode 100755 diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg old mode 100644 new mode 100755 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/export-model-state-dict-pretrained-out.txt b/docs/source/model-export/code/export-model-state-dict-pretrained-out.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-model-state-dict.rst b/docs/source/model-export/export-model-state-dict.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst old mode 
100644 new mode 100755 diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/export-with-torch-jit-trace.rst b/docs/source/model-export/export-with-torch-jit-trace.rst old mode 100644 new mode 100755 diff --git a/docs/source/model-export/index.rst b/docs/source/model-export/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Finetune/adapter/finetune_adapter.rst b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/index.rst b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst b/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/index.rst b/docs/source/recipes/Non-streaming-ASR/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png old mode 100644 
new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/timit/index.rst b/docs/source/recipes/Non-streaming-ASR/timit/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/yesno/index.rst b/docs/source/recipes/Non-streaming-ASR/yesno/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst b/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/RNN-LM/index.rst b/docs/source/recipes/RNN-LM/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/index.rst b/docs/source/recipes/Streaming-ASR/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg old mode 100644 new 
mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst old mode 100644 new mode 100755 diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst old mode 100644 new mode 100755 diff --git a/egs/aidatatang_200zh/ASR/README.md b/egs/aidatatang_200zh/ASR/README.md deleted file mode 100644 index 035139d17..000000000 --- a/egs/aidatatang_200zh/ASR/README.md +++ /dev/null @@ -1,38 +0,0 @@ -Note: This recipe is trained with the codes from this PR https://github.com/k2-fsa/icefall/pull/375 -# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall. -The model was trained on full [Aidatatang_200zh](https://www.openslr.org/62) with the scripts in [icefall](https://github.com/k2-fsa/icefall) based on the latest version k2. -## Training procedure -The main repositories are list below, we will update the training and decoding scripts with the update of version. -k2: https://github.com/k2-fsa/k2 -icefall: https://github.com/k2-fsa/icefall -lhotse: https://github.com/lhotse-speech/lhotse -* Install k2 and lhotse, k2 installation guide refers to https://k2-fsa.github.io/k2/installation/index.html, lhotse refers to https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. I think the latest version would be ok. And please also install the requirements listed in icefall. -* Clone icefall(https://github.com/k2-fsa/icefall) and check to the commit showed above. -``` -git clone https://github.com/k2-fsa/icefall -cd icefall -``` -* Preparing data. -``` -cd egs/aidatatang_200zh/ASR -bash ./prepare.sh -``` -* Training -``` -export CUDA_VISIBLE_DEVICES="0,1" -./pruned_transducer_stateless2/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 250 -``` -## Evaluation results -The decoding results (WER%) on Aidatatang_200zh(dev and test) are listed below, we got this result by averaging models from epoch 11 to 29. 
-The WERs are -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 | -| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 | -| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500| diff --git a/egs/aidatatang_200zh/ASR/RESULTS.md b/egs/aidatatang_200zh/ASR/RESULTS.md deleted file mode 100644 index 5b82fb61f..000000000 --- a/egs/aidatatang_200zh/ASR/RESULTS.md +++ /dev/null @@ -1,72 +0,0 @@ -## Results - -### Aidatatang_200zh Char training results (Pruned Transducer Stateless2) - -#### 2022-05-16 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/375. - -The WERs are - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 | -| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 | -| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500| - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1" - -./pruned_transducer_stateless2/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 250 \ - --save-every-n 1000 - -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/xS7kgYf2RwyDpQAOdS8rAA/#scalars - -The decoding command is: -``` -epoch=29 -avg=19 - -## greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 100 - -## modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -A pre-trained model and decoding logs can be found at diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py deleted file mode 100755 index 9caacb78b..000000000 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aidatatang_200zh dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): - src_dir = Path("data/manifests/aidatatang_200zh") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train", - "dev", - "test", - ) - prefix = "aidatatang" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - - for sup in m["supervisions"]: - sup.custom = {"origin": "aidatatang_200zh"} - - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_aidatatang_200zh( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed - ) diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py b/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py deleted file mode 100644 index d66e5cfca..000000000 --- a/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. 
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/fbank/aidatatang_cuts_train.jsonl.gz", - "./data/fbank/aidatatang_cuts_dev.jsonl.gz", - "./data/fbank/aidatatang_cuts_test.jsonl.gz", - ] - - for path in paths: - print(f"Starting display the statistics for {path}") - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Starting display the statistics for ./data/fbank/aidatatang_cuts_train.jsonl.gz -Cuts count: 494715 -Total duration (hours): 422.6 -Speech duration (hours): 422.6 (100.0%) -*** -Duration statistics (seconds): -mean 3.1 -std 1.2 -min 1.0 -25% 2.3 -50% 2.7 -75% 3.5 -99% 7.2 -99.5% 8.0 -99.9% 9.5 -max 18.1 -Starting display the statistics for ./data/fbank/aidatatang_cuts_dev.jsonl.gz -Cuts count: 24216 -Total duration (hours): 20.2 -Speech duration (hours): 20.2 (100.0%) -*** -Duration statistics (seconds): -mean 3.0 -std 1.0 -min 1.2 -25% 2.3 -50% 2.7 -75% 3.4 -99% 6.7 -99.5% 7.3 -99.9% 8.8 -max 11.3 -Starting display the statistics for ./data/fbank/aidatatang_cuts_test.jsonl.gz -Cuts count: 48144 -Total duration (hours): 40.2 -Speech duration (hours): 40.2 (100.0%) -*** -Duration statistics (seconds): -mean 3.0 -std 1.1 -min 0.9 -25% 2.3 -50% 2.6 -75% 3.4 -99% 6.9 -99.5% 7.5 -99.9% 9.0 -max 21.8 -""" diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py deleted file mode 100755 index 6b440dfb3..000000000 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/text, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. 
- """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. - Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - chars = list(word.strip(" \t")) - if contain_oov(token_sym_table, chars): - continue - lexicon.append((word, chars)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. 
- """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([ \t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - chars = list(line) - for char in chars: - if char not in tokens: - tokens[char] = len(tokens) - return tokens - - -def main(): - lang_dir = Path("data/lang_char") - text_file = lang_dir / "text" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py deleted file mode 100755 index c8cf9b881..000000000 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. 
-""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. 
- Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") - return parser.parse_args() - - -def main(): - out_dir = Path(get_args().lang_dir) - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/local/prepare_words.py b/egs/aidatatang_200zh/ASR/local/prepare_words.py deleted file mode 100755 index 65aca2983..000000000 --- a/egs/aidatatang_200zh/ASR/local/prepare_words.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input words.txt without ids: - - words_no_ids.txt -and generates the new words.txt with related ids. - - words.txt -""" - - -import argparse -import logging - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare words.txt", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/words_no_ids.txt", - type=str, - help="the words file without ids for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/words.txt", - type=str, - help="the words file with ids for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - add_words = [" 0", "!SIL 1", " 2", " 3"] - new_lines.extend(add_words) - - logging.info("Starting reading the input file") - for i in tqdm(range(len(lines))): - x = lines[i] - idx = 4 + i - new_line = str(x.strip("\n")) + " " + str(idx) - new_lines.append(new_line) - - logging.info("Starting writing the words.txt") - f_out = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py deleted file mode 100755 index 74e025ad7..000000000 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py deleted file mode 100755 index 85047c367..000000000 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
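Reviewer note: the text2token.py removed just below tokenizes Chinese text either character by character or via pypinyin. For reference, the two pypinyin calls it relies on behave differently: lazy_pinyin returns plain syllables without tone marks, while pinyin returns per-character lists with tone marks. A toy example (requires the pypinyin package):

    from pypinyin import lazy_pinyin, pinyin

    text = "你好"
    print(lazy_pinyin(list(text)))  # ['ni', 'hao']
    print(pinyin(list(text)))       # [['nǐ'], ['hǎo']]
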
- - -import argparse -import codecs -import re -import sys -from typing import List - -from pypinyin import lazy_pinyin, pinyin - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symobles, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "pinyin", "lazy_pinyin"], - help="""Transcript type. char/pinyin/lazy_pinyin""", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the chinese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "pinyin" and "lazy_pinyin". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "lazy_pinyin": - text = lazy_pinyin(chars_list) - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - else: # token_type = "pinyin" - text = pinyin(chars_list) - sub_ids = [ - token_table[txt[0]] if txt[0] in token_table else oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "pinyin": - a = pinyin(list(str(a))) - a = [one[0] for one in a] - - if args.trans_type == "lazy_pinyin": - a = lazy_pinyin(list(str(a))) - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = "".join(a_flat) - print(a_chars) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh deleted file mode 100755 index 09dfd5fac..000000000 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 -perturb_speed=true - - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/aidatatang_200zh -# You can find "corpus" and "transcript" inside it. -# You can download it at https://openslr.org/62/ -# If you download the data by yourself, DON'T FORGET to extract the *.tar.gz files under corpus. - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - if [ ! 
-f $dl_dir/aidatatang_200zh/transcript/aidatatang_200_zh_transcript.txt ]; then - lhotse download aidatatang-200zh $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare aidatatang_200zh manifest" - # We assume that you have downloaded the aidatatang_200zh corpus - # to $dl_dir/aidatatang_200zh - if [ ! -f data/manifests/aidatatang_200zh/.manifests.done ]; then - mkdir -p data/manifests/aidatatang_200zh - lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh - touch data/manifests/aidatatang_200zh/.manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests/ - lhotse prepare musan $dl_dir/musan data/manifests/ - touch data/manifests/.manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for aidatatang_200zh" - if [ ! -f data/fbank/.aidatatang_200zh.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py --perturb-speed ${perturb_speed} - touch data/fbank/.aidatatang_200zh.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - lang_char_dir=data/lang_char - mkdir -p $lang_char_dir - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. cp jq /usr/bin - if [ ! -f $lang_char_dir/text ]; then - gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \ - |jq '.text' |sed -e 's/["text:\t ]*//g' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - fi - # Prepare words.txt - if [ ! -f $lang_char_dir/text_words ]; then - gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \ - | jq '.text' | sed -e 's/["text:\t]*//g' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_words - fi - - cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ - | uniq > $lang_char_dir/words_no_ids.txt - - if [ ! -f $lang_char_dir/words.txt ]; then - ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py - fi -fi diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index e29dd8ab5..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,412 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - -set_caching_enabled(False) -torch.set_num_threads(1) - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class Aidatatang_200zhAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/dev/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=False, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. 
This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - dev_iter_dataset = IterableDatasetWrapper( - dataset=validate, - sampler=valid_sampler, - ) - valid_dl = DataLoader( - dev_iter_dataset, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - test_iter_dataset = IterableDatasetWrapper( - dataset=test, - sampler=sampler, - ) - test_dl = DataLoader( - test_iter_dataset, - batch_size=None, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aidatatang_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aidatatang_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aidatatang_cuts_test.jsonl.gz" - ) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ 
-../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index 2512f233f..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,536 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -When training with the L subset, usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 6 \ - --avg 3 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 6 \ - --avg 3 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 6 \ - --avg 3 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import Aidatatang_200zhAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--batch", - type=int, - default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. 
- """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 50 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - Aidatatang_200zhAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - 
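Reviewer note: the checkpoint handling a few lines below relies on icefall's average_checkpoints(), which (simplified) averages the "model" state_dicts of several saved epochs element-wise before decoding. A dependency-free sketch of that idea, with hypothetical filenames; the real helper additionally handles device placement and non-float entries:

    import torch

    def average_state_dicts(filenames):
        # Element-wise mean of the "model" state_dict stored in each checkpoint.
        avg = None
        for name in filenames:
            sd = torch.load(name, map_location="cpu")["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in sd.items()}
            else:
                for k in avg:
                    avg[k] += sd[k].float()
        return {k: v / len(filenames) for k, v in avg.items()}

    # e.g. average_state_dicts([f"exp/epoch-{i}.pt" for i in range(25, 29)])
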
logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - elif params.batch is not None: - filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt" - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints([filenames], device=device)) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) - - dev_cuts = aidatatang_200zh.valid_cuts() - test_cuts = aidatatang_200zh.test_cuts() - dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts) - test_dl = aidatatang_200zh.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index 5179bfa1c..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 29 \ - --avg 19 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aidatatang_200zh/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - # is defined in local/train_bpe_model.py - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- 
a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py deleted file mode 100644 index 17729e02e..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ /dev/null @@ -1,339 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by -./pruned_transducer_stateless2/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search and modified_beam_search ", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): 
- encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - hyps = [] - msg = f"Using {params.decoding_method}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py deleted file mode 100644 index fa809b768..000000000 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,958 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1" - -./pruned_transducer_stateless2/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 250 \ - --save-every-n 1000 - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 250 \ - --save-every-n 1000 - --use-fp16 True - -""" - -import argparse -import logging -import os -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import Aidatatang_200zhAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - -os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12359, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. 
- If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 10, - "log_interval": 1, - "reset_interval": 200, - "valid_interval": 400, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 200, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - 
return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. 
If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) - - train_cuts = aidatatang_200zh.train_cuts() - valid_cuts = aidatatang_200zh.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 10.0 seconds - # - # Caution: There is a reason to select 10.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 10.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - valid_dl = aidatatang_200zh.valid_dataloaders(valid_cuts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aidatatang_200zh.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - if not params.print_diagnostics and params.start_batch == 0: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - Aidatatang_200zhAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/shared b/egs/aidatatang_200zh/ASR/shared deleted file mode 120000 index 3a3b28f96..000000000 --- a/egs/aidatatang_200zh/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../egs/aishell/ASR/shared \ No newline at end of file diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md deleted file mode 100644 index d088072a7..000000000 --- a/egs/aishell/ASR/README.md +++ /dev/null @@ -1,35 +0,0 @@ - -# Introduction - -Please refer to for how to run models in this recipe. - -Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd. -400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition. - -(From [Open Speech and Language Resources](https://www.openslr.org/33/)) - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------| -| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` | -| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | -| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | -| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| -| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 | -| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | - - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. 
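For orientation, the sketch below shows the idea of that stateless prediction network in plain PyTorch: an embedding layer followed by a causal Conv1d over the previous `context_size` labels, with no recurrent state. It is illustrative only and does not reproduce the exact `Decoder` class used in these recipes; the dimensions and the depthwise grouping of the convolution are assumptions.

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d prediction network (illustrative, not the recipes' Decoder)."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int, blank_id: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=blank_id)
        self.context_size = context_size
        if context_size > 1:
            # Depthwise causal convolution over the last `context_size` label embeddings.
            self.conv = nn.Conv1d(
                embed_dim, embed_dim, kernel_size=context_size, groups=embed_dim, bias=False
            )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) previously emitted labels
        out = self.embedding(y)  # (N, U, C)
        if self.context_size > 1:
            x = out.permute(0, 2, 1)  # (N, C, U)
            x = nn.functional.pad(x, (self.context_size - 1, 0))  # left-pad => causal
            out = self.conv(x).permute(0, 2, 1)
        return torch.relu(out)


decoder = StatelessDecoder(vocab_size=4336, embed_dim=512, context_size=2)
print(decoder(torch.randint(1, 4336, (4, 7))).shape)  # torch.Size([4, 7, 512])
```

Because the decoder only sees a bounded label context, beam-search hypotheses that end in the same `context_size` labels can reuse the same decoder output, which is what keeps the "stateless" design cheap at decoding time.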
- -# Whisper - -Recipe to finetune large pretrained models -| | Encoder | Decoder | Comment | -|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------| -| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md deleted file mode 100644 index 355d1516d..000000000 --- a/egs/aishell/ASR/RESULTS.md +++ /dev/null @@ -1,943 +0,0 @@ -## Results - -### Aishell training results (Fine-tuning Pretrained Models) -#### Whisper -[./whisper](./whisper) -##### fine-tuning results on Aishell test set on whisper medium, large-v2, large-v3 - -| | test (before fine-tuning) | test (after fine-tuning) | comment | -|------------------------|------|------|-----------------------------------------| -| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp | -| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 | -| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 | - -Command for training is: -```bash -pip install -r whisper/requirements.txt - -./prepare.sh --stage 30 --stop_stage 30 - -#fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json - -# fine-tuning with ddp -torchrun --nproc-per-node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_medium \ - --base-lr 1e-5 \ - --model-name medium -``` - -Command for decoding using fine-tuned models: -```bash -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 -``` -Command for decoding using pretrained models (before fine-tuning): -```bash -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch -1 --avg 1 \ - --remove-whisper-encoder-input-length-restriction False \ - --beam-size 10 --max-duration 50 -``` -Fine-tuned models, training logs, decoding logs, tensorboard and decoding results -are available at - - -### Aishell training result (Stateless Transducer) - -#### Zipformer (Byte-level BPE) - -[./zipformer](./zipformer/) - -It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500. 
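As a rough illustration of what "Byte-level BPE" means here: the training text is first mapped to a sequence of byte symbols, and a SentencePiece BPE model with 500 pieces is trained on and applied to that byte representation, so arbitrary Chinese text is covered without a large character vocabulary. The snippet below is a simplified sketch under that assumption; the `byte_encode` helper and the `<0x..>` symbol format are illustrative stand-ins rather than the recipe's exact implementation, and only the model path `data/lang_bbpe_500/bbpe.model` is taken from the commands below.

```python
import sentencepiece as spm


def byte_encode(text: str) -> str:
    # Illustrative: spell out every UTF-8 byte as a printable symbol, e.g. 0xE5 -> "<0xE5>".
    return " ".join(f"<0x{b:02X}>" for b in text.encode("utf-8"))


# Assumes a BPE model that was trained on the same byte-symbol representation.
sp = spm.SentencePieceProcessor(model_file="data/lang_bbpe_500/bbpe.model")
ids = sp.encode(byte_encode("大家好"), out_type=int)
print(ids)  # token ids used as transducer training targets
```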
- -##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M - -| | test | dev | comment | -|------------------------|------|------|-----------------------------------------| -| greedy search | 4.54 | 4.31 | --epoch 40 --avg 10 | -| modified beam search | 4.37 | 4.11 | --epoch 40 --avg 10 | -| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | - -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1" - -./zipformer/train_bbpe.py \ - --world-size 2 \ - --num-epochs 40 \ - --start-epoch 1 \ - --use-fp16 1 \ - --context-size 2 \ - --enable-musan 0 \ - --exp-dir zipformer/exp_bbpe \ - --max-duration 1000 \ - --enable-musan 0 \ - --base-lr 0.045 \ - --lr-batches 7500 \ - --lr-epochs 10 \ - --spec-aug-time-warp-factor 20 -``` - -Command for decoding is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./zipformer/decode_bbpe.py \ - --epoch 40 \ - --avg 10 \ - --exp-dir ./zipformer_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --context-size 2 \ - --decoding-method $m -done -``` -Pretrained models, training logs, decoding logs, tensorboard and decoding results -are available at - - - -#### Zipformer (Non-streaming) - -[./zipformer](./zipformer/) - -It's reworked Zipformer with Pruned RNNT loss. -**Caution**: It uses `--context-size=1`. - -##### normal-scaled model, number of model parameters: 73412551, i.e., 73.41 M - -| | test | dev | comment | -|------------------------|------|------|-----------------------------------------| -| greedy search | 4.67 | 4.37 | --epoch 55 --avg 17 | -| modified beam search | 4.40 | 4.13 | --epoch 55 --avg 17 | -| fast beam search | 4.60 | 4.31 | --epoch 55 --avg 17 | - -Command for training is: -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1" - -./zipformer/train.py \ - --world-size 2 \ - --num-epochs 60 \ - --start-epoch 1 \ - --use-fp16 1 \ - --context-size 1 \ - --enable-musan 0 \ - --exp-dir zipformer/exp \ - --max-duration 1000 \ - --enable-musan 0 \ - --base-lr 0.045 \ - --lr-batches 7500 \ - --lr-epochs 18 \ - --spec-aug-time-warp-factor 20 -``` - -Command for decoding is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./zipformer/decode.py \ - --epoch 55 \ - --avg 17 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --context-size 1 \ - --decoding-method $m -done -``` -Pretrained models, training logs, decoding logs, tensorboard and decoding results -are available at - - - -##### small-scaled model, number of model parameters: 30167139, i.e., 30.17 M - -| | test | dev | comment | -|------------------------|------|------|-----------------------------------------| -| greedy search | 4.97 | 4.67 | --epoch 55 --avg 21 | -| modified beam search | 4.67 | 4.40 | --epoch 55 --avg 21 | -| fast beam search | 4.85 | 4.61 | --epoch 55 --avg 21 | - -Command for training is: -```bash -export CUDA_VISIBLE_DEVICES="0,1" - -./zipformer/train.py \ - --world-size 2 \ - --num-epochs 60 \ - --start-epoch 1 \ - --use-fp16 1 \ - --context-size 1 \ - --exp-dir zipformer/exp-small \ - --enable-musan 0 \ - --base-lr 0.045 \ - --lr-batches 7500 \ - --lr-epochs 18 \ - --spec-aug-time-warp-factor 20 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 \ - --max-duration 1200 -``` - -Command for decoding is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./zipformer/decode.py \ - --epoch 55 \ - --avg 21 \ - 
--exp-dir ./zipformer/exp-small \ - --lang-dir data/lang_char \ - --context-size 1 \ - --decoding-method $m \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 -done -``` - -Pretrained models, training logs, decoding logs, tensorboard and decoding results -are available at - - -##### large-scaled model, number of model parameters: 157285130, i.e., 157.29 M - -| | test | dev | comment | -|------------------------|------|------|-----------------------------------------| -| greedy search | 4.49 | 4.22 | --epoch 56 --avg 23 | -| modified beam search | 4.28 | 4.03 | --epoch 56 --avg 23 | -| fast beam search | 4.44 | 4.18 | --epoch 56 --avg 23 | - -Command for training is: -```bash -export CUDA_VISIBLE_DEVICES="0,1" - -./zipformer/train.py \ - --world-size 2 \ - --num-epochs 60 \ - --use-fp16 1 \ - --context-size 1 \ - --exp-dir ./zipformer/exp-large \ - --enable-musan 0 \ - --lr-batches 7500 \ - --lr-epochs 18 \ - --spec-aug-time-warp-factor 20 \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --max-duration 800 -``` - -Command for decoding is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./zipformer/decode.py \ - --epoch 56 \ - --avg 23 \ - --exp-dir ./zipformer/exp-large \ - --lang-dir data/lang_char \ - --context-size 1 \ - --decoding-method $m \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 -done -``` - -Pretrained models, training logs, decoding logs, tensorboard and decoding results -are available at - - -#### Pruned transducer stateless 7 streaming -[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) - -It's Streaming version of Zipformer1 with Pruned RNNT loss. - -| | test | dev | comment | -|------------------------|------|------|---------------------------------------| -| greedy search | 6.95 | 6.29 | --epoch 44 --avg 15 --max-duration 600 | -| modified beam search | 6.51 | 5.90 | --epoch 44 --avg 15 --max-duration 600 | -| fast beam search | 6.73 | 6.09 | --epoch 44 --avg 15 --max-duration 600 | - -Training command is: - -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 2 \ - --num-epochs 50 \ - --use-fp16 1 \ - --context-size 1 \ - --max-duration 800 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --enable-musan 0 \ - --spec-aug-time-warp-factor 20 -``` - -**Caution**: It uses `--context-size=1`. - -The decoding command is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 44 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang-dir data/lang_char \ - --context-size 1 \ - --decoding-method $m -done -``` - -Pretrained models, training logs, decoding logs, tensorboard and decoding results -are available at - - - - -#### Pruned transducer stateless 7 - -[./pruned_transducer_stateless7](./pruned_transducer_stateless7) - -It's Zipformer with Pruned RNNT loss. 
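The decoding methods compared throughout this file (`greedy_search`, `modified_beam_search`, `fast_beam_search`) all consume the same encoder output and differ only in how hypotheses are expanded. For orientation, a bare-bones greedy search for a stateless transducer looks roughly like the sketch below; it is a simplified illustration, not the recipes' `decode.py`, and the `decoder`/`joiner` interfaces are assumptions.

```python
import torch


@torch.no_grad()
def greedy_search(decoder, joiner, encoder_out, blank_id=0, context_size=1, max_sym_per_frame=1):
    """Frame-synchronous greedy decoding for a single utterance (illustrative).

    Assumed interfaces:
      decoder(context)         context: (1, context_size) label ids -> (1, D) state
      joiner(enc_frame, state) enc_frame: (1, E), state: (1, D)     -> (1, vocab) logits
      encoder_out              (T, E) encoder frames
    """
    hyp = [blank_id] * context_size  # seed the label context with blanks
    for t in range(encoder_out.size(0)):
        enc_frame = encoder_out[t : t + 1]  # (1, E)
        for _ in range(max_sym_per_frame):
            context = torch.tensor([hyp[-context_size:]], dtype=torch.long)
            logits = joiner(enc_frame, decoder(context))
            y = int(logits.argmax(dim=-1))
            if y == blank_id:
                break  # nothing more to emit on this frame
            hyp.append(y)
    return hyp[context_size:]
```

Modified beam search keeps the `--beam-size` best label sequences instead of a single one, and fast beam search constrains the expansion with an FSA-based decoding graph; that is why their CERs in the tables are usually slightly better than greedy search, at some extra decoding cost.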
- -| | test | dev | comment | -|------------------------|------|------|---------------------------------------| -| greedy search | 5.02 | 4.61 | --epoch 42 --avg 6 --max-duration 600 | -| modified beam search | 4.81 | 4.4 | --epoch 42 --avg 6 --max-duration 600 | -| fast beam search | 4.91 | 4.52 | --epoch 42 --avg 6 --max-duration 600 | - -Training command is: - -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1" - -./pruned_transducer_stateless7/train.py \ - --world-size 2 \ - --num-epochs 50 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --context-size 1 \ - --max-duration 300 -``` - -**Caution**: It uses `--context-size=1`. - -The tensorboard log is available at - - -The decoding command is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./pruned_transducer_stateless7/decode.py \ - --epoch 42 \ - --avg 6 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --max-duration 300 \ - --context-size 1 \ - --decoding-method $m - -done -``` - -Pretrained models, training logs, decoding logs, and decoding results -are available at - -#### Pruned transducer stateless 7 (Byte-level BPE) - -See - -[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe) - -**Note**: The modeling units are byte level BPEs - -The best results I have gotten are: - -Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments --- | -- | -- | -- | -- | -- -500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29 - -The training command: - -``` -export CUDA_VISIBLE_DEVICES="4,5,6,7" - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --num-epochs 50 \ - --start-epoch 1 \ - --use-fp16 1 \ - --max-duration 800 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --lr-epochs 6 \ - --master-port 12535 -``` - -The decoding command: - -``` -for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do - ./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 48 \ - --avg 29 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-sym-per-frame 1 \ - --ngram-lm-scale 0.25 \ - --ilme-scale 0.2 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 2000 \ - --decoding-method $m -done -``` - -The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe - - -#### Pruned transducer stateless 3 - -See - - -[./pruned_transducer_stateless3](./pruned_transducer_stateless3) - -It uses pruned RNN-T. 
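A note on the `--epoch N --avg M` options that appear in the comment columns and decoding commands throughout this file: decoding typically does not use a single checkpoint but an element-wise average of the parameters from the last M epoch checkpoints ending at `epoch-N.pt` (some recipes use the `--use-averaged-model` variant, which averages in a slightly different way). The sketch below shows the basic idea; it mirrors, but is not, icefall's actual checkpoint-averaging code, and the checkpoint layout (weights stored under a `"model"` key) and the file paths are assumptions.

```python
import torch


def average_checkpoints(filenames):
    """Element-wise average of the model weights stored in `filenames` (illustrative)."""
    avg = None
    for f in filenames:
        # Assumption: each checkpoint stores its weights under the "model" key.
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}


# "--epoch 29 --avg 5" corresponds to averaging epoch-25.pt ... epoch-29.pt:
ckpts = [f"pruned_transducer_stateless3/exp/epoch-{i}.pt" for i in range(25, 30)]
# model.load_state_dict(average_checkpoints(ckpts))
```

Averaging several end-of-training checkpoints usually gives a small but consistent CER improvement over the last checkpoint alone, which is why each result reports the `(epoch, avg)` pair that worked best.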
- -| | test | dev | comment | -|------------------------|------|------|---------------------------------------| -| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | -| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | -| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 | -| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 | -| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | - -Training command is: - -```bash -./prepare.sh -./prepare_aidatatang_200zh.sh - -export CUDA_VISIBLE_DEVICES="4,5,6,7" - -./pruned_transducer_stateless3/train.py \ - --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ - --world-size 4 \ - --max-duration 200 \ - --datatang-prob 0.5 \ - --start-epoch 1 \ - --num-epochs 30 \ - --use-fp16 1 \ - --num-encoder-layers 12 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --context-size 1 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --master-port 12356 -``` - -**Caution**: It uses `--context-size=1`. - -The tensorboard log is available at - - -The decoding command is: - -```bash -for epoch in 29; do - for avg in 5; do - for m in greedy_search modified_beam_search fast_beam_search; do - ./pruned_transducer_stateless3/decode.py \ - --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ - --epoch $epoch \ - --avg $avg \ - --use-averaged-model 1 \ - --max-duration 600 \ - --decoding-method $m \ - --num-encoder-layers 12 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --context-size 1 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 - done - done -done -``` - -We provide the option of shallow fusion with a RNN language model. The pre-trained language model is -available at . 
To decode with the language model, -please use the following command: - -```bash -# download pre-trained model -git lfs install -git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 - -aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/ - -pushd ${aishell_exp}/exp -ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt -popd - -# download RNN LM -git lfs install -git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm -rnnlm_dir=icefall-aishell-rnn-lm - -# RNNLM shallow fusion -for lm_scale in $(seq 0.26 0.02 0.34); do - python ./pruned_transducer_stateless3/decode.py \ - --epoch 99 \ - --avg 1 \ - --lang-dir ${aishell_exp}/data/lang_char \ - --exp-dir ${aishell_exp}/exp \ - --use-averaged-model False \ - --decoding-method modified_beam_search_lm_shallow_fusion \ - --use-shallow-fusion 1 \ - --lm-type rnn \ - --lm-exp-dir ${rnnlm_dir}/exp \ - --lm-epoch 99 \ - --lm-scale $lm_scale \ - --lm-avg 1 \ - --rnn-lm-embedding-dim 2048 \ - --rnn-lm-hidden-dim 2048 \ - --rnn-lm-num-layers 2 \ - --lm-vocab-size 4336 -done - -# RNNLM Low-order density ratio (LODR) with a 2-gram - -cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt - -for lm_scale in 0.48; do - for LODR_scale in -0.28; do - python ./pruned_transducer_stateless3/decode.py \ - --epoch 99 \ - --avg 1 \ - --lang-dir ${aishell_exp}/data/lang_char \ - --exp-dir ${aishell_exp}/exp \ - --use-averaged-model False \ - --decoding-method modified_beam_search_LODR \ - --use-shallow-fusion 1 \ - --lm-type rnn \ - --lm-exp-dir ${rnnlm_dir}/exp \ - --lm-epoch 99 \ - --lm-scale $lm_scale \ - --lm-avg 1 \ - --rnn-lm-embedding-dim 2048 \ - --rnn-lm-hidden-dim 2048 \ - --rnn-lm-num-layers 2 \ - --lm-vocab-size 4336 \ - --tokens-ngram 2 \ - --backoff-id 4336 \ - --ngram-lm-scale $LODR_scale - done -done - -``` - -Pretrained models, training logs, decoding logs, and decoding results -are available at - - -We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how -to use the pre-trained model for non-streaming ASR. See - - - -#### Pruned transducer stateless 2 - -See https://github.com/k2-fsa/icefall/pull/536 - -[./pruned_transducer_stateless2](./pruned_transducer_stateless2) - -It uses pruned RNN-T. 
- -| | test | dev | comment | -| -------------------- | ---- | ---- | -------------------------------------- | -| greedy search | 5.20 | 4.78 | --epoch 72 --avg 14 --max-duration 200 | -| modified beam search | 5.07 | 4.63 | --epoch 72 --avg 14 --max-duration 200 | -| fast beam search | 5.13 | 4.70 | --epoch 72 --avg 14 --max-duration 200 | - -Training command is: - -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1" - -./pruned_transducer_stateless2/train.py \ - --world-size 2 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --max-duration 200 \ -``` - -The tensorboard log is available at -https://tensorboard.dev/experiment/QI3PVzrGRrebxpbWUPwmkA/ - -The decoding command is: -```bash -for m in greedy_search modified_beam_search fast_beam_search ; do - ./pruned_transducer_stateless2/decode.py \ - --epoch 72 \ - --avg 14 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 200 \ - --decoding-method $m - -done -``` - -Pretrained models, training logs, decoding logs, and decoding results -are available at - - - -#### 2022-03-01 - -[./transducer_stateless_modified-2](./transducer_stateless_modified-2) - -It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer) -for computing RNN-T loss. - -Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data. - - -| | test |comment | -|------------------------|------|----------------------------------------------------------------| -| greedy search | 4.94 |--epoch 89, --avg 38, --max-duration 100, --max-sym-per-frame 1 | -| modified beam search | 4.68 |--epoch 89, --avg 38, --max-duration 100 --beam-size 4 | - -The training commands are: - -```bash -cd egs/aishell/ASR -./prepare.sh --stop-stage 6 -./prepare_aidatatang_200zh.sh - -export CUDA_VISIBLE_DEVICES="0,1,2" - -./transducer_stateless_modified-2/train.py \ - --world-size 3 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 250 \ - --lr-factor 2.0 \ - --context-size 2 \ - --modified-transducer-prob 0.25 \ - --datatang-prob 0.2 -``` - -The tensorboard log is available at - - -The commands for decoding are - -```bash -# greedy search -for epoch in 89; do - for avg in 38; do - ./transducer_stateless_modified-2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 - done -done - -# modified beam search -for epoch in 89; do - for avg in 38; do - ./transducer_stateless_modified-2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method modified_beam_search \ - --beam-size 4 - done -done -``` - -You can find a pre-trained model, decoding logs, and decoding results at - - -#### 2022-03-01 - -[./transducer_stateless_modified](./transducer_stateless_modified) - -Stateless transducer + modified transducer. 
- -| | test |comment | -|------------------------|------|----------------------------------------------------------------| -| greedy search | 5.22 |--epoch 64, --avg 33, --max-duration 100, --max-sym-per-frame 1 | -| modified beam search | 5.02 |--epoch 64, --avg 33, --max-duration 100 --beam-size 4 | - -The training commands are: - -```bash -cd egs/aishell/ASR -./prepare.sh --stop-stage 6 - -export CUDA_VISIBLE_DEVICES="0,1,2" - -./transducer_stateless_modified/train.py \ - --world-size 3 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified/exp-4 \ - --max-duration 250 \ - --lr-factor 2.0 \ - --context-size 2 \ - --modified-transducer-prob 0.25 -``` - -The tensorboard log is available at - - -The commands for decoding are - -```bash -# greedy search -for epoch in 64; do - for avg in 33; do - ./transducer_stateless_modified/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified/exp-4 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 - done -done - -# modified beam search -for epoch in 64; do - for avg in 33; do - ./transducer_stateless_modified/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified/exp-4 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method modified_beam_search \ - --beam-size 4 - done -done -``` - -You can find a pre-trained model, decoding logs, and decoding results at - - - -#### 2022-2-19 -(Duo Ma): The tensorboard log for training is available at https://tensorboard.dev/experiment/25PmX3MxSVGTdvIdhOwllw/#scalars -You can find a pretrained model by visiting https://huggingface.co/shuanguanma/icefall_aishell_transducer_stateless_context_size2_epoch60_2022_2_19 -| | test |comment | -|---------------------------|------|-----------------------------------------| -| greedy search | 5.4 |--epoch 59, --avg 10, --max-duration 100 | -| beam search | 5.05|--epoch 59, --avg 10, --max-duration 100 | - -You can use the following commands to reproduce our results: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" -python3 ./transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 60 \ - --start-epoch 0 \ - --exp-dir exp/transducer_stateless_context_size2 \ - --max-duration 100 \ - --lr-factor 2.5 \ - --context-size 2 - -lang_dir=data/lang_char -dir=exp/transducer_stateless_context_size2 -python3 ./transducer_stateless/decode.py \ - --epoch 59 \ - --avg 10 \ - --exp-dir $dir \ - --lang-dir $lang_dir \ - --decoding-method greedy_search \ - --context-size 2 \ - --max-sym-per-frame 3 - -lang_dir=data/lang_char -dir=exp/transducer_stateless_context_size2 -python3 ./transducer_stateless/decode.py \ - --epoch 59 \ - --avg 10 \ - --exp-dir $dir \ - --lang-dir $lang_dir \ - --decoding-method beam_search \ - --context-size 2 \ - --max-sym-per-frame 3 -``` - -#### 2022-02-18 -(Pingfeng Luo) : The tensorboard log for training is available at -And pretrained model is available at - -||test| -|--|--| -|CER| 5.05% | - -You can use the following commands to reproduce our results: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8" -./transducer_stateless/train.py \ - --bucketing-sampler True \ - --world-size 8 \ - --lang-dir data/lang_char \ - --num-epochs 60 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp_rnnt_k2 \ - --max-duration 80 \ - --lr-factor 3 - -./transducer_stateless/decode.py \ - --epoch 59 \ - --avg 10 \ - --lang-dir data/lang_char \ - --exp-dir transducer_stateless/exp_rnnt_k2 \ - 
--max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 -``` - -### Aishell training results (Conformer-MMI) -#### 2021-12-04 -(Pingfeng Luo): Result of - -The tensorboard log for training is available at - -And pretrained model is available at - -The best decoding results (CER) are listed below, we got this results by averaging models from epoch 61 to 85, and using `attention-decoder` decoder with num_paths equals to 100. - -||test| -|--|--| -|CER| 4.94% | - -||lm_scale|attention_scale| -|--|--|--| -|test|1.1|0.3| - -You can use the following commands to reproduce our results: - -```bash -git clone https://github.com/k2-fsa/icefall -cd icefall - -cd egs/aishell/ASR -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8" -python conformer_mmi/train.py --bucketing-sampler True \ - --max-duration 200 \ - --start-epoch 0 \ - --num-epochs 90 \ - --world-size 8 - -python conformer_mmi/decode.py --nbest-scale 0.5 \ - --epoch 85 \ - --avg 25 \ - --method attention-decoder \ - --max-duration 20 \ - --num-paths 100 -``` - -### Aishell training results (Conformer-CTC) -#### 2021-11-16 -(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30 - -Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc - -The best decoding results (CER) are listed below, we got this results by averaging models from epoch 60 to 84, and using `attention-decoder` decoder with num_paths equals to 100. - -||test| -|--|--| -|CER| 4.26% | - -To get more unique paths, we scaled the lattice.scores with 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details), we searched the lm_score_scale and attention_score_scale for best results, the scales that produced the CER above are also listed below. - -||lm_scale|attention_scale| -|--|--|--| -|test|0.3|0.9| - -You can use the following commands to reproduce our results: - -```bash -git clone https://github.com/k2-fsa/icefall -cd icefall - -cd egs/aishell/ASR -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1,2,3" -python conformer_ctc/train.py --bucketing-sampler True \ - --max-duration 200 \ - --start-epoch 0 \ - --num-epochs 90 \ - --world-size 4 - -python conformer_ctc/decode.py --nbest-scale 0.5 \ - --epoch 84 \ - --avg 25 \ - --method attention-decoder \ - --max-duration 20 \ - --num-paths 100 -``` - -### Aishell training results (Tdnn-Lstm) -#### 2021-09-13 - -(Wei Kang): Result of phone based Tdnn-Lstm model, https://github.com/k2-fsa/icefall/pull/30 - -Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc_lstm_ctc - -The best decoding results (CER) are listed below, we got this results by averaging models from epoch 19 to 8, and using `1best` decoding method. - -||test| -|--|--| -|CER| 10.16% | diff --git a/egs/aishell/ASR/conformer_ctc/README.md b/egs/aishell/ASR/conformer_ctc/README.md deleted file mode 100644 index 41637159d..000000000 --- a/egs/aishell/ASR/conformer_ctc/README.md +++ /dev/null @@ -1,4 +0,0 @@ - -Please visit - -for how to run this recipe. 
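A brief note on the `attention-decoder` rescoring used in the Conformer-MMI and Conformer-CTC results above: `num_paths` unique paths are drawn from the lattice, every path receives a combined score, and the best-scoring path is returned; the `lm_scale` and `attention_scale` tables list the scale pair that gave the best CER. The sketch below is only a schematic of that score combination, not the actual icefall rescoring code; the path structure and the toy scores are made up for illustration.

```python
from collections import namedtuple

Path = namedtuple("Path", "text am_score lm_score attn_score")


def rescore_path(p, lm_scale=0.3, attention_scale=0.9):
    # Schematic combined score for one n-best path; the default scales are the
    # ones reported for the Conformer-CTC test set above.
    return p.am_score + lm_scale * p.lm_score + attention_scale * p.attn_score


nbest = [
    Path("path-1", am_score=-10.2, lm_score=-3.1, attn_score=-2.0),
    Path("path-2", am_score=-10.5, lm_score=-2.9, attn_score=-1.7),
]
print(max(nbest, key=rescore_path).text)
```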
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py deleted file mode 100644 index ab1cbbae4..000000000 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ /dev/null @@ -1,895 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -import warnings -from typing import Optional, Tuple - -import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask - - -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. - use_feat_batchnorm(bool): whether to use batch-normalize the input. - """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. 
- See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). 
- S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout( - self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - ) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src - - -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None - ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for mod in self.layers: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. 
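RelPositionalEncoding keeps one embedding per relative offset, 2*max_len - 1 entries in total, and its forward pass (further down) slices the centred window pe[:, center - T + 1 : center + T] for an input of length T. The toy example below shows only that indexing arithmetic; the ordering comment (most positive offset first) follows the usual espnet-style layout and is an assumption here rather than something spelled out in the hunk above.

import torch

d_model, max_len = 8, 10
pe = torch.randn(1, 2 * max_len - 1, d_model)      # toy table, one row per offset
# assumed layout: offsets run from +(max_len-1) down to -(max_len-1)
offsets = torch.arange(max_len - 1, -max_len, -1)

T = 4                                              # current input length
center = pe.size(1) // 2
window = pe[:, center - T + 1 : center + T]        # shape (1, 2*T - 1, d_model)
assert window.size(1) == 2 * T - 1
print(offsets[center - T + 1 : center + T])        # tensor([ 3,  2,  1,  0, -1, -2, -3])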
- - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. 
- self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
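The rel_shift method defined a little further down in this class turns the (batch, head, time1, 2*time1-1) scores over relative offsets into an ordinary (batch, head, time1, time2) matrix with a single as_strided view. The snippet below copies those stride arguments and checks them against a naive gather, i.e. out[..., i, j] == x[..., i, j - i + time1 - 1]; the check is an illustration added here, not part of the deleted file.

import torch

def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # Same strided view as the deleted method: (B, H, T, 2*T-1) -> (B, H, T, T).
    b, h, t, n = x.shape
    assert n == 2 * t - 1
    s0, s1, s2, s3 = x.stride()
    return x.as_strided((b, h, t, t), (s0, s1, s2 - s3, s3),
                        storage_offset=s3 * (t - 1))

x = torch.randn(2, 3, 5, 9)
out = rel_shift(x)
# naive gather: row i keeps the entries at relative index j - i + (t - 1)
naive = torch.stack([x[..., i, (torch.arange(5) - i) + 4] for i in range(5)], dim=-2)
assert torch.equal(out, naive)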
- """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 
- 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." 
- ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights 
over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py deleted file mode 100755 index 2cb476e20..000000000 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ /dev/null @@ -1,572 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
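The ConvolutionModule removed just above runs a pointwise convolution that doubles the channels, a GLU gate, a depthwise convolution, batch norm, Swish, and a final pointwise convolution, with (time, batch, channel) tensors on both ends. The standalone walk-through below restates that pipeline with functional calls (nn.functional.silu standing in for the Swish class, since Swish(x) = x * sigmoid(x)); the sizes are made up for illustration.

import torch
import torch.nn as nn

channels, kernel_size, T, N = 16, 31, 50, 4
pw1 = nn.Conv1d(channels, 2 * channels, 1)
dw = nn.Conv1d(channels, channels, kernel_size,
               padding=(kernel_size - 1) // 2, groups=channels)
bn = nn.BatchNorm1d(channels)
pw2 = nn.Conv1d(channels, channels, 1)

x = torch.rand(T, N, channels)
y = x.permute(1, 2, 0)                    # (N, C, T) for Conv1d
y = nn.functional.glu(pw1(y), dim=1)      # gate halves the 2*C channels back to C
y = pw2(nn.functional.silu(bn(dw(y))))    # SiLU == Swish
y = y.permute(2, 0, 1)                    # back to (T, N, C)
assert y.shape == x.shape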
- -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to - tokens using token symbol tabel directly. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) attention-decoder. Extract n paths from the lattice, - the path with the highest score is the decoding result. - - (4) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The LM dir. 
- It should contain either G_3_gram.pt or G_3_gram.fst.txt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - "nhead": 4, - "attention_dim": 512, - "num_encoder_layers": 12, - "num_decoder_layers": 6, - "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - batch: dict, - lexicon: Lexicon, - sos_id: int, - eos_id: int, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if decoding method is 1best, the key is the string `no_rescore`. - If attention rescoring is used, the key is the string - `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the - value of `lm_scale` and `attention_scale`. An example key is - `ngram_lm_scale_0.7_attention_scale_0.5` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "attention-decoder", it uses attention rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains the token symbol table and the word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
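decode_one_batch below converts the batch's supervision dict into the int32 supervision_segments tensor that get_lattice expects, dividing start_frame and num_frames by the subsampling factor. A worked example with invented frame counts:

import torch

supervisions = {                                    # values invented for illustration
    "sequence_idx": torch.tensor([0, 1, 2]),
    "start_frame": torch.tensor([0, 0, 0]),
    "num_frames": torch.tensor([1000, 832, 604]),
}
subsampling_factor = 4
supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        supervisions["start_frame"] // subsampling_factor,
        supervisions["num_frames"] // subsampling_factor,
    ),
    1,
).to(torch.int32)
# tensor([[  0,   0, 250],
#         [  1,   0, 208],
#         [  2,   0, 151]], dtype=torch.int32)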
- """ - if HLG is not None: - device = HLG.device - else: - device = H.device - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - key = "ctc-decoding" - hyps = [[lexicon.token_table[i] for i in ids] for ids in token_ids] - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=lexicon.word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method == "attention-decoder" - - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - lexicon: Lexicon, - sos_id: int, - eos_id: int, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. 
- - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - lexicon: - It contains the token symbol table and the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - Return a dict, whose key may be "no-rescore" if the decoding method is - 1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention - rescoring is used. Its value is a list of tuples. Each tuple contains two - elements: The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - batch=batch, - lexicon=lexicon, - sos_id=sos_id, - eos_id=eos_id, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=enable_log, - compute_CER=True, - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - else: - H = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - lexicon=lexicon, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py deleted file mode 120000 index 896b78aef..000000000 --- a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py deleted file mode 120000 index aa1b6073d..000000000 --- a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py deleted file mode 120000 index 0cf42ce30..000000000 --- a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/aishell/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py deleted file mode 100755 index af1171a6f..000000000 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ /dev/null @@ -1,375 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
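When --avg is larger than 1, the main() function above loads the state dicts of the last --avg epoch checkpoints and averages them via average_checkpoints. The helper below is only a sketch of that idea (an element-wise mean of the "model" entries the checkpoints store), not icefall's implementation; integer buffers such as BatchNorm counters are glossed over here.

import torch

def average_state_dicts(filenames):
    # Element-wise mean of the "model" state dicts stored in the checkpoints.
    # (Sketch only: integer buffers, e.g. num_batches_tracked, would need
    # special handling before load_state_dict.)
    avg = None
    for f in filenames:
        sd = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}

# epoch, avg = 49, 20
# filenames = [f"conformer_ctc/exp/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]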
- - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from conformer import Conformer -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_attention_decoder, -) -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens-file", - type=str, - help="Path to tokens.txt" "Used only when method is ctc-decoding", - ) - - parser.add_argument( - "--words-file", - type=str, - help="Path to words.txt" "Used when method is NOT ctc-decoding", - ) - - parser.add_argument( - "--HLG", - type=str, - help="Path to HLG.pt." "Used when method is NOT ctc-decoding", - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use ctc decoding. It maps the tokens ids to tokens - using the token symbol table directly. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) attention-decoder - Extract n paths from the rescored - lattice and use the transformer attention decoder for - rescoring. - We call it HLG decoding + n-gram LM rescoring + attention - decoder rescoring. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.3, - help=""" - Used only when method is attention-decoder. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--attention-decoder-scale", - type=float, - default=0.9, - help=""" - Used only when method is attention-decoder. - It specifies the scale for attention decoder scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help=""" - Used only when method is attention-decoder. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the SOS token. - """, - ) - - parser.add_argument( - "--eos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the EOS token. - """, - ) - - parser.add_argument( - "--num_classes", - type=int, - default=4336, - help="The Vocab size.", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - "nhead": 4, - "attention_dim": 512, - "num_decoder_layers": 6, - "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for deocding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - if args.method != "attention-decoder": - # to save memory as the attention decoder - # will not be used - params.num_decoder_layers = 0 - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output, memory, memory_key_padding_mask = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - token_sym_table = k2.SymbolTable.from_file(params.tokens_file) - max_token_id = params.num_classes - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=True, - device=device, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - 
output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_sym_table[i] for i in ids] for ids in token_ids] - elif params.method in ["1best", "attention-decoder"]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "attention-decoder": - logging.info("Use HLG + attention decoder rescoring") - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py deleted file mode 100755 index 81fa234dd..000000000 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
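pretrained.py above batches the per-utterance fbank matrices with pad_sequence, padding with log(1e-10) so that padded frames carry a very small log-energy. A minimal reproduction with invented lengths (80-dim features, as in the params above):

import math
import torch
from torch.nn.utils.rnn import pad_sequence

features = [torch.rand(312, 80), torch.rand(256, 80), torch.rand(198, 80)]
batch = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(batch.shape)        # torch.Size([3, 312, 80])
print(batch[1, 256:, 0])  # every padded entry equals log(1e-10) ≈ -23.03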
- - -import torch -from subsampling import Conv2dSubsampling, VggSubsampling - - -def test_conv2d_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = Conv2dSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim - - -def test_vgg_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = VggSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim diff --git a/egs/aishell/ASR/conformer_ctc/test_transformer.py b/egs/aishell/ASR/conformer_ctc/test_transformer.py deleted file mode 100755 index 7c0695683..000000000 --- a/egs/aishell/ASR/conformer_ctc/test_transformer.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from torch.nn.utils.rnn import pad_sequence -from transformer import ( - Transformer, - add_eos, - add_sos, - decoder_padding_mask, - encoder_padding_mask, - generate_square_subsequent_mask, -) - - -def test_encoder_padding_mask(): - supervisions = { - "sequence_idx": torch.tensor([0, 1, 2]), - "start_frame": torch.tensor([0, 0, 0]), - "num_frames": torch.tensor([18, 7, 13]), - } - - max_len = ((18 - 1) // 2 - 1) // 2 - mask = encoder_padding_mask(max_len, supervisions) - expected_mask = torch.tensor( - [ - [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, - [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, - [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_transformer(): - num_features = 40 - num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) - - N = 31 - - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) - - -def test_generate_square_subsequent_mask(): - s = 5 - mask = generate_square_subsequent_mask(s) - inf = float("inf") - expected_mask = torch.tensor( - [ - [0.0, -inf, -inf, -inf, -inf], - [0.0, 0.0, -inf, -inf, -inf], - [0.0, 0.0, 0.0, -inf, -inf], - [0.0, 0.0, 0.0, 0.0, -inf], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_decoder_padding_mask(): - x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] - y = pad_sequence(x, batch_first=True, padding_value=-1) - mask = decoder_padding_mask(y, ignore_id=-1) - expected_mask = torch.tensor( - [[False, False, True], [False, True, True], [False, False, False]] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_add_sos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_sos(x, sos_id=0) - expected_y = [[0, 1, 2], [0, 3], [0, 
2, 5, 8]] - assert y == expected_y - - -def test_add_eos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_eos(x, eos_id=0) - expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] - assert y == expected_y diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py deleted file mode 100755 index c2cbe6e3b..000000000 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ /dev/null @@ -1,681 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer -from lhotse.utils import fix_random_seed -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=90, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.7, - help="""The attention rate. 
- The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - subsampling_factor: The subsampling factor for the model. - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - attention_dim: Attention dimension. - - - nhead: Number of heads in multi-head attention. - Must satisfy attention_dim // nhead == 0. - - - num_encoder_layers: Number of attention encoder layers. - - - num_decoder_layers: Number of attention decoder layers. - - - use_feat_batchnorm: Whether to do normalization in the input layer. - - - weight_decay: The weight_decay for the optimizer. - - - lr_factor: The lr_factor for the optimizer. - - - warm_step: The warm_step for the optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for k2.ctc_loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - "attention_dim": 512, - "nhead": 4, - "num_encoder_layers": 12, - "num_decoder_layers": 6, - "use_feat_batchnorm": True, - # parameters for Noam - "weight_decay": 1e-5, - "lr_factor": 5.0, - "warm_step": 36000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. 
- """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CharCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[torch.Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py deleted file mode 100644 index a3e50e385..000000000 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ /dev/null @@ -1,924 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache 
License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling, VggSubsampling -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class Transformer(nn.Module): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. - num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. - """ - super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_classes) - # to the shape (N, T//subsampling_factor, d_model). 
- # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) - - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) - - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - - def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, T, C) - - Encoder output with shape (T, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, T). - It is None if `supervision` is None. - """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. 
- Returns: - Return a tuple with two tensors: - - The encoder output, with shape (T, N, C) - - encoder padding mask, with shape (N, T). - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - return x, mask - - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is (T, N, C) - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is (N, T, C) - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - @torch.jit.export - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. 
- Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. 
- - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). 
- - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer (required). - memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). - - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). 
- S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. 
- - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), - True denote the masked indices. 
- """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. 
- """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/aishell/ASR/conformer_mmi/asr_datamodule.py b/egs/aishell/ASR/conformer_mmi/asr_datamodule.py deleted file mode 120000 index a73848de9..000000000 --- a/egs/aishell/ASR/conformer_mmi/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py deleted file mode 100644 index ab1cbbae4..000000000 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ /dev/null @@ -1,895 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -import warnings -from typing import Optional, Tuple - -import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask - - -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. - use_feat_batchnorm(bool): whether to use batch-normalize the input. 
- """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. 
- - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout( - self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - ) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src - - -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). 
- - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None - ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for mod in self.layers: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). 
- - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. 
- - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. 
When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
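rel_shift above converts scores indexed by relative position, shape (batch, head, time1, 2*time1-1), into scores indexed by key position, shape (batch, head, time1, time1), purely by re-striding the same storage. A naive loop-based reference (a hypothetical helper, for illustration only) that reproduces the same view:

    import torch

    def rel_shift_reference(x: torch.Tensor) -> torch.Tensor:
        # out[b, h, i, j] is the score for relative position (i - j), which the
        # input stores at index (j - i + time1 - 1) along its last axis.
        b, h, t, n = x.shape
        assert n == 2 * t - 1
        out = x.new_empty(b, h, t, t)
        for i in range(t):
            for j in range(t):
                out[:, :, i, j] = x[:, :, i, j - i + t - 1]
        return out

    x = torch.randn(2, 4, 5, 9)
    strided = x.as_strided(
        (2, 4, 5, 5),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * 4,
    )
    assert torch.allclose(strided, rel_shift_reference(x))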
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). 
- - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py deleted file mode 100755 index 8a2daa93e..000000000 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ /dev/null @@ -1,589 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) -# Copyright 2021 Pingfeng Luo -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
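The pointwise-conv + GLU followed by a depthwise conv in ConvolutionModule above keeps the time length unchanged because kernel_size is odd and the padding is (kernel_size - 1) // 2. A minimal shape check of that pattern (channel count and kernel size here are illustrative):

    import torch
    import torch.nn as nn

    T, N, C, kernel_size = 100, 4, 256, 31
    x = torch.randn(T, N, C).permute(1, 2, 0)                # (N, C, T)
    x = nn.functional.glu(nn.Conv1d(C, 2 * C, 1)(x), dim=1)  # pointwise + GLU -> (N, C, T)
    x = nn.Conv1d(C, C, kernel_size, padding=(kernel_size - 1) // 2, groups=C)(x)
    assert x.shape == (N, C, T)  # odd kernel with (k - 1)//2 padding preserves T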
- - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to - tokens using token symbol tabel directly. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) attention-decoder. Extract n paths from the lattice, - the path with the highest score is the decoding result. - - (4) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_mmi/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_phone", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The LM dir. 
- It should contain either G_3_gram.pt or G_3_gram.fst.txt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - "nhead": 4, - "attention_dim": 512, - "num_encoder_layers": 12, - "num_decoder_layers": 6, - "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - batch: dict, - lexicon: Lexicon, - sos_id: int, - eos_id: int, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if decoding method is 1best, the key is the string `no_rescore`. - If attention rescoring is used, the key is the string - `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the - value of `lm_scale` and `attention_scale`. An example key is - `ngram_lm_scale_0.7_attention_scale_0.5` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "attention-decoder", it uses attention rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains the token symbol table and the word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
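Concretely, the dict described above might look like the following for a two-utterance batch when attention rescoring is used; the keys depend on which (ngram_lm_scale, attention_scale) pairs rescore_with_attention_decoder tried, and the hypotheses here are made up:

    hyps_dict = {
        "ngram_lm_scale_0.7_attention_scale_0.5": [
            ["今", "天", "天", "气"],  # hypothesis for utterance 0 in the batch
            ["好", "冷"],              # hypothesis for utterance 1 in the batch
        ],
        "ngram_lm_scale_0.9_attention_scale_0.5": [
            ["今", "天", "天", "气"],
            ["好", "了"],
        ],
    }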
- """ - if HLG is not None: - device = HLG.device - else: - device = H.device - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - key = "ctc-decoding" - hyps = [[lexicon.token_table[i] for i in ids] for ids in token_ids] - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=lexicon.word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method == "attention-decoder" - - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - lexicon: Lexicon, - sos_id: int, - eos_id: int, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. 
- - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - lexicon: - It contains the token symbol table and the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - Return a dict, whose key may be "no-rescore" if the decoding method is - 1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention - rescoring is used. Its value is a list of tuples. Each tuple contains two - elements: The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - batch=batch, - lexicon=lexicon, - sos_id=sos_id, - eos_id=eos_id, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=enable_log, - compute_CER=True, - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = MmiTrainingGraphCompiler( - args.lang_dir, - device=device, - oov="", - sos_id=1, - eos_id=1, - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - else: - H = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - lexicon=lexicon, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py deleted file mode 120000 index 08734abd7..000000000 --- a/egs/aishell/ASR/conformer_mmi/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py deleted file mode 100644 index 398837a46..000000000 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, idim: int, odim: int) -> None: - """ - Args: - idim: - Input dim. The input shape is [N, T, idim]. - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] - """ - assert idim >= 7 - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is [N, T, idim]. 
- - Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] - """ - # On entry, x is [N, T, idim] - x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] - x = self.conv(x) - # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] - return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where - T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is [N, T, idim]. - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is [N, T, idim]. - - Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py deleted file mode 100755 index 09cd6e60c..000000000 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ /dev/null @@ -1,678 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# Copyright 2021 Pingfeng Luo -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
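A quick numeric check of the output-length arithmetic shared by the two subsampling front-ends above, i.e. T' = ((T - 1) // 2 - 1) // 2 for two stride-2, kernel-3 convolutions without padding (the channel count in this sketch is arbitrary):

    import torch
    import torch.nn as nn

    T, idim = 100, 80
    convs = nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=3, stride=2),
        nn.Conv2d(4, 4, kernel_size=3, stride=2),
    )
    out = convs(torch.randn(1, 1, T, idim))
    assert out.shape[2] == ((T - 1) // 2 - 1) // 2 == 24     # time axis
    assert out.shape[3] == ((idim - 1) // 2 - 1) // 2 == 19  # feature axis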
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer -from lhotse.utils import fix_random_seed -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.mmi import LFMMILoss -from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=90, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_mmi/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_mmi/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_phone", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.7, - help="""The attention rate. - The total loss is (1 - att_rate) * mmi_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - subsampling_factor: The subsampling factor for the model. - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - attention_dim: Attention dimension. - - - nhead: Number of heads in multi-head attention. - Must satisfy attention_dim // nhead == 0. - - - num_encoder_layers: Number of attention encoder layers. - - - num_decoder_layers: Number of attention decoder layers. - - - use_feat_batchnorm: Whether to do normalization in the input layer. - - - weight_decay: The weight_decay for the optimizer. - - - lr_factor: The lr_factor for the optimizer. - - - warm_step: The warm_step for the optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for k2.ctc_loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - "attention_dim": 512, - "nhead": 4, - "num_encoder_layers": 12, - "num_decoder_layers": 6, - "use_feat_batchnorm": True, - # parameters for Noam - "weight_decay": 1e-6, - "lr_factor": 5.0, - "warm_step": 80000, - "use_pruned_intersect": False, - "den_scale": 1.0, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: MmiTrainingGraphCompiler, - is_training: bool, -) -> Tuple[torch.Tensor, MetricsTracker]: - """ - Compute LF-MMI loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `LFMMILoss.forward()` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss_fn = LFMMILoss( - graph_compiler=graph_compiler, - den_scale=params.den_scale, - use_pruned_intersect=params.use_pruned_intersect, - ) - - mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) - - if params.att_rate != 0.0: - token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss - else: - loss = mmi_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["mmi_loss"] = mmi_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: MmiTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. 
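The interpolation in compute_loss above combines the two objectives as (1 - att_rate) * mmi_loss + att_rate * att_loss. With the parser default att_rate=0.7 and some made-up per-batch loss sums:

    att_rate = 0.7                      # parser default shown above
    mmi_loss, att_loss = 120.0, 80.0    # hypothetical per-batch sums
    loss = (1.0 - att_rate) * mmi_loss + att_rate * att_loss
    assert abs(loss - 92.0) < 1e-6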
The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: MmiTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. 
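The running tot_loss in train_one_epoch above is a geometric moving average: every batch the accumulated statistics are scaled by 1 - 1/reset_interval before the new loss_info is added, so a contribution from reset_interval batches ago carries a weight of roughly 1/e:

    reset_interval = 200  # from get_params() above
    weight_after_reset_interval = (1 - 1 / reset_interval) ** reset_interval
    print(round(weight_after_reset_interval, 3))  # 0.367, roughly 1/e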
- The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = MmiTrainingGraphCompiler( - params.lang_dir, - device=device, - oov="", - sos_id=1, - eos_id=1, - ) - - logging.info("About to create model") - if params.att_rate == 0: - assert params.num_decoder_layers == 0, f"{params.num_decoder_layers}" - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints and checkpoints["optimizer"]: - optimizer.load_state_dict(checkpoints["optimizer"]) - - aishell = AishellAsrDataModule(args) - train_cuts = aishell.train_cuts() - train_dl = aishell.train_dataloaders(train_cuts) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py deleted file mode 100644 index a3e50e385..000000000 --- 
a/egs/aishell/ASR/conformer_mmi/transformer.py +++ /dev/null @@ -1,924 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling, VggSubsampling -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class Transformer(nn.Module): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. - num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. - """ - super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_classes) - # to the shape (N, T//subsampling_factor, d_model). 
- # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) - - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) - - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - - def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, T, C) - - Encoder output with shape (T, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, T). - It is None if `supervision` is None. - """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. 
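In shapes, the encoder path of forward above (implemented by run_encoder and ctc_output below) looks like this, assuming a batch of 8 utterances of 1000 frames, feature_dim=80 as in this recipe's params, and the default d_model=256:

    # x:                 (8, 1000, 80)          model input (N, T, C)
    # encoder_embed(x):  (8, 249, 256)          T -> ((T - 1) // 2 - 1) // 2
    # encoder_pos(x):    (8, 249, 256)          add positional encoding
    # permute(1, 0, 2):  (249, 8, 256)          (T, N, C) for nn.TransformerEncoder
    # encoder(...):      (249, 8, 256)          memory, also fed to the decoder
    # ctc_output(...):   (8, 249, num_classes)  log-probabilities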
- Returns: - Return a tuple with two tensors: - - The encoder output, with shape (T, N, C) - - encoder padding mask, with shape (N, T). - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - return x, mask - - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is (T, N, C) - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is (N, T, C) - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - @torch.jit.export - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. 
- Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. 
- - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). 
- - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer (required). - memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). - - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). 
-            S is the source sequence length, T is the target sequence length,
-            N is the batch size, E is the feature number
-        """
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-        tgt2 = self.self_attn(
-            tgt,
-            tgt,
-            tgt,
-            attn_mask=tgt_mask,
-            key_padding_mask=tgt_key_padding_mask,
-        )[0]
-        tgt = residual + self.dropout1(tgt2)
-        if not self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm2(tgt)
-        tgt2 = self.src_attn(
-            tgt,
-            memory,
-            memory,
-            attn_mask=memory_mask,
-            key_padding_mask=memory_key_padding_mask,
-        )[0]
-        tgt = residual + self.dropout2(tgt2)
-        if not self.normalize_before:
-            tgt = self.norm2(tgt)
-
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm3(tgt)
-        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
-        tgt = residual + self.dropout3(tgt2)
-        if not self.normalize_before:
-            tgt = self.norm3(tgt)
-        return tgt
-
-
-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
-
-
-class PositionalEncoding(nn.Module):
-    """This class implements the positional encoding
-    proposed in the following paper:
-
-    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
-
-        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
-        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
-
-    Note::
-
-      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
-                               = exp(-1 * 2i / d_model * log(10000))
-                               = exp(2i * -(log(10000) / d_model))
-    """
-
-    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
-        """
-        Args:
-          d_model:
-            Embedding dimension.
-          dropout:
-            Dropout probability to be applied to the output of this module.
-        """
-        super().__init__()
-        self.d_model = d_model
-        self.xscale = math.sqrt(self.d_model)
-        self.dropout = nn.Dropout(p=dropout)
-        # not doing: self.pe = None because of errors thrown by torchscript
-        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
-
-    def extend_pe(self, x: torch.Tensor) -> None:
-        """Extend the time t in the positional encoding if required.
-
-        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
-        is (N, T, d_model). If T > T1, then we change the shape of self.pe
-        to (N, T, d_model). Otherwise, nothing is done.
-
-        Args:
-          x:
-            It is a tensor of shape (N, T, C).
-        Returns:
-          Return None.
-        """
-        if self.pe is not None:
-            if self.pe.size(1) >= x.size(1):
-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
-        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.d_model, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.d_model)
-        )
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        # Now pe is of shape (1, T, d_model), where T is x.size(1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Add positional encoding.
-
-        Args:
-          x:
-            Its shape is (N, T, C)
-
-        Returns:
-          Return a tensor of shape (N, T, C)
-        """
-        self.extend_pe(x)
-        x = x * self.xscale + self.pe[:, : x.size(1), :]
-        return self.dropout(x)
-
-
-class Noam(object):
-    """
-    Implements Noam optimizer.
-
-    Proposed in
-    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
-
-    Args:
-      params:
-        iterable of parameters to optimize or dicts defining parameter groups
-      model_size:
-        attention dimension of the transformer model
-      factor:
-        learning rate factor
-      warm_step:
-        warmup steps
-    """
-
-    def __init__(
-        self,
-        params,
-        model_size: int = 256,
-        factor: float = 10.0,
-        warm_step: int = 25000,
-        weight_decay=0,
-    ) -> None:
-        """Construct a Noam object."""
-        self.optimizer = torch.optim.Adam(
-            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
-        )
-        self._step = 0
-        self.warmup = warm_step
-        self.factor = factor
-        self.model_size = model_size
-        self._rate = 0
-
-    @property
-    def param_groups(self):
-        """Return param_groups."""
-        return self.optimizer.param_groups
-
-    def step(self):
-        """Update parameters and rate."""
-        self._step += 1
-        rate = self.rate()
-        for p in self.optimizer.param_groups:
-            p["lr"] = rate
-        self._rate = rate
-        self.optimizer.step()
-
-    def rate(self, step=None):
-        """Compute the Noam learning rate for the given (or current) step."""
-        if step is None:
-            step = self._step
-        return (
-            self.factor
-            * self.model_size ** (-0.5)
-            * min(step ** (-0.5), step * self.warmup ** (-1.5))
-        )
-
-    def zero_grad(self):
-        """Reset gradient."""
-        self.optimizer.zero_grad()
-
-    def state_dict(self):
-        """Return state_dict."""
-        return {
-            "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        """Load state_dict."""
-        for key, value in state_dict.items():
-            if key == "optimizer":
-                self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
-
-
-def encoder_padding_mask(
-    max_len: int, supervisions: Optional[Supervisions] = None
-) -> Optional[torch.Tensor]:
-    """Make a mask tensor containing indices of the padded parts.
-
-    TODO::
-      This function **assumes** that the model uses
-      a subsampling factor of 4. We should remove that
-      assumption later.
-
-    Args:
-      max_len:
-        Maximum length of input features.
-        CAUTION: It is the length after subsampling.
-      supervisions:
-        Supervision in lhotse format.
-        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
-        (CAUTION: It contains length information, i.e., start and number of
-        frames, before subsampling)
-
-    Returns:
-      Tensor: Mask tensor of dimension (batch_size, input_length),
-      True denotes the masked indices.
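# A minimal, self-contained sketch of the mask conventions used by the helpers
# in this file, assuming only PyTorch: padding masks are boolean with True at
# padded positions, and the causal mask is float with -inf above the diagonal,
# as in generate_square_subsequent_mask(). The lengths below are made-up examples.
import torch

lengths = torch.tensor([4, 2, 3])  # valid frames per utterance (after subsampling)
max_len = int(lengths.max())
seq_range = torch.arange(max_len).unsqueeze(0)    # (1, T)
padding_mask = seq_range >= lengths.unsqueeze(1)  # (N, T), True means padded
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [False, False, False,  True]])

causal_mask = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])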
- """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. 
- """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/aishell/ASR/local/compile_hlg.py b/egs/aishell/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/aishell/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/compile_lg.py b/egs/aishell/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/aishell/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py deleted file mode 100755 index 6a9bb4f42..000000000 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aidatatang_200zh dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train", - "test", - "dev", - ) - prefix = "aidatatang" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - - for sup in m["supervisions"]: - sup.custom = {"origin": "aidatatang_200zh"} - - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_aidatatang_200zh( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed - ) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py deleted file mode 100755 index 3c48f0aa1..000000000 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aishell dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_aishell( - num_mel_bins: int = 80, - perturb_speed: bool = False, - whisper_fbank: bool = False, - output_dir: str = "data/fbank", -): - src_dir = Path("data/manifests") - output_dir = Path(output_dir) - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train", - "dev", - "test", - ) - prefix = "aishell" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info("Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--output-dir", - type=str, - default="data/fbank", - help="Output directory. 
Default: data/fbank.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_aishell( - num_mel_bins=args.num_mel_bins, - perturb_speed=args.perturb_speed, - whisper_fbank=args.whisper_fbank, - output_dir=args.output_dir, - ) diff --git a/egs/aishell/ASR/local/compute_fbank_musan.py b/egs/aishell/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/aishell/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py b/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/display_manifest_statistics.py b/egs/aishell/ASR/local/display_manifest_statistics.py deleted file mode 100755 index c478f7331..000000000 --- a/egs/aishell/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,196 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer_stateless/train.py -for usage. 
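# A small illustrative sketch of how these duration statistics are typically
# used: filtering a lhotse CutSet by duration before training. The 1.0 s /
# 12.0 s bounds are example values read off the tables below, not values taken
# from the recipe; remove_short_and_long_utt_example() is a hypothetical helper.
from lhotse import load_manifest_lazy


def remove_short_and_long_utt_example(path: str):
    cuts = load_manifest_lazy(path)
    # Keep only cuts whose duration falls inside the chosen range.
    return cuts.filter(lambda c: 1.0 <= c.duration <= 12.0)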
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - # path = "./data/fbank/aishell_cuts_train.jsonl.gz" - # path = "./data/fbank/aishell_cuts_test.jsonl.gz" - path = "./data/fbank/aishell_cuts_dev.jsonl.gz" - # path = "./data/fbank/aidatatang_cuts_train.jsonl.gz" - # path = "./data/fbank/aidatatang_cuts_test.jsonl.gz" - # path = "./data/fbank/aidatatang_cuts_dev.jsonl.gz" - - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -## train (after speed perturb) -Cuts count: 360294 -Total duration (hours): 455.6 -Speech duration (hours): 455.6 (100.0%) -*** -Duration statistics (seconds): -mean 4.6 -std 1.4 -min 1.1 -0.1% 1.8 -0.5% 2.2 -1% 2.3 -5% 2.7 -10% 3.0 -10% 3.0 -25% 3.5 -50% 4.3 -75% 5.4 -90% 6.5 -95% 7.2 -99% 8.8 -99.5% 9.4 -99.9% 10.9 -max 16.1 - -## test -Cuts count: 7176 -Total duration (hours): 10.0 -Speech duration (hours): 10.0 (100.0%) -*** -Duration statistics (seconds): -mean 5.0 -std 1.6 -min 1.9 -0.1% 2.2 -0.5% 2.4 -1% 2.6 -5% 3.0 -10% 3.2 -10% 3.2 -25% 3.8 -50% 4.7 -75% 5.9 -90% 7.3 -95% 8.2 -99% 9.9 -99.5% 10.7 -99.9% 11.9 -max 14.7 - -## dev -Cuts count: 14326 -Total duration (hours): 18.1 -Speech duration (hours): 18.1 (100.0%) -*** -Duration statistics (seconds): -mean 4.5 -std 1.3 -min 1.6 -0.1% 2.1 -0.5% 2.3 -1% 2.4 -5% 2.9 -10% 3.1 -10% 3.1 -25% 3.5 -50% 4.3 -75% 5.4 -90% 6.4 -95% 7.0 -99% 8.4 -99.5% 8.9 -99.9% 10.3 -max 12.5 - -## aidatatang_200zh (train) -Cuts count: 164905 -Total duration (hours): 139.9 -Speech duration (hours): 139.9 (100.0%) -*** -Duration statistics (seconds): -mean 3.1 -std 1.1 -min 1.1 -0.1% 1.5 -0.5% 1.7 -1% 1.8 -5% 2.0 -10% 2.1 -10% 2.1 -25% 2.3 -50% 2.7 -75% 3.4 -90% 4.6 -95% 5.4 -99% 7.1 -99.5% 7.8 -99.9% 9.1 -max 16.3 - -## aidatatang_200zh (test) -Cuts count: 48144 -Total duration (hours): 40.2 -Speech duration (hours): 40.2 (100.0%) -*** -Duration statistics (seconds): -mean 3.0 -std 1.1 -min 0.9 -0.1% 1.5 -0.5% 1.8 -1% 1.8 -5% 2.0 -10% 2.1 -10% 2.1 -25% 2.3 -50% 2.6 -75% 3.4 -90% 4.4 -95% 5.2 -99% 6.9 -99.5% 7.5 -99.9% 9.0 -max 21.8 - -## aidatatang_200zh (dev) -Cuts count: 24216 -Total duration (hours): 20.2 -Speech duration (hours): 20.2 (100.0%) -*** -Duration statistics (seconds): -mean 3.0 -std 1.0 -min 1.2 -0.1% 1.6 -0.5% 1.7 -1% 1.8 -5% 2.0 -10% 2.1 -10% 2.1 -25% 2.3 -50% 2.7 -75% 3.4 -90% 4.4 -95% 5.1 -99% 6.7 -99.5% 7.3 -99.9% 8.8 -max 11.3 -""" diff --git a/egs/aishell/ASR/local/generate_unique_lexicon.py b/egs/aishell/ASR/local/generate_unique_lexicon.py deleted file mode 120000 index c0aea1403..000000000 --- a/egs/aishell/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py deleted file mode 100755 index 8cc0502c2..000000000 --- a/egs/aishell/ASR/local/prepare_char.py +++ /dev/null @@ -1,259 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/text, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. - Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. 
- - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - chars = list(word.strip(" \t")) - if contain_oov(token_sym_table, chars): - continue - lexicon.append((word, chars)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. - """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([ \t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - chars = list(line) - for char in chars: - if char not in tokens: - tokens[char] = len(tokens) - return tokens - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - text_file = lang_dir / "text" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py deleted file mode 100755 index e7995680b..000000000 --- a/egs/aishell/ASR/local/prepare_char_lm_training_data.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey -# Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script takes a `tokens.txt` and a text file such as
-./download/lm/aishell-transcript.txt
-and outputs the LM training data to a supplied directory such
-as data/lm_training_char. The format is as follows:
-it creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
-representation of a dict in the same format as the librispeech recipe.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import torch
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-char",
-        type=str,
-        help="""Lang dir of the ASR model, e.g. data/lang_char""",
-    )
-    parser.add_argument(
-        "--lm-data",
-        type=str,
-        help="""Input LM training data as text, e.g.
-        download/lm/aishell-train-word.txt""",
-    )
-    parser.add_argument(
-        "--lm-archive",
-        type=str,
-        help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
-        look at the source of this script to see the format.""",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-
-    if Path(args.lm_archive).exists():
-        logging.warning(f"{args.lm_archive} exists - skipping")
-        return
-
-    # make token_dict from tokens.txt in order to map characters to tokens.
-    token_dict = {}
-    token_file = args.lang_char + "/tokens.txt"
-
-    with open(token_file, "r") as f:
-        for line in f.readlines():
-            line_list = line.split()
-            token_dict[line_list[0]] = int(line_list[1])
-
-    # word2index is a dictionary from words to integer ids. No need to reserve
-    # space for epsilon, etc.; the words are just used as a convenient way to
-    # compress the sequences of tokens.
-    word2index = dict()
-
-    word2token = []  # Will be a list-of-list-of-int, representing tokens.
-    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
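    # An illustrative sketch of the layout built below, using a hypothetical
    # two-line corpus ["你好 世界", "你好"] and a token_dict that maps each
    # character to an id:
    #   word2index = {"你好": 0, "世界": 1}
    #   word2token = [[id_你, id_好], [id_世, id_界]]  # token ids per distinct word
    #   sentences  = [[0, 1], [0]]                     # word ids per input line
    # i.e. sentences index into word2token, which in turn stores token ids.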
- - if "aishell-lm" in args.lm_data: - num_lines_in_total = 120098.0 - step = 50000 - elif "valid" in args.lm_data: - num_lines_in_total = 14326.0 - step = 3000 - elif "test" in args.lm_data: - num_lines_in_total = 7176.0 - step = 3000 - else: - num_lines_in_total = None - step = None - - processed = 0 - - with open(args.lm_data) as f: - while True: - line = f.readline() - if line == "": - break - - if step and processed % step == 0: - logging.info( - f"Processed number of lines: {processed} " - f"({processed / num_lines_in_total * 100: .3f}%)" - ) - processed += 1 - - line_words = line.split() - for w in line_words: - if w not in word2index: - w_token = [] - for t in w: - if t in token_dict: - w_token.append(token_dict[t]) - else: - w_token.append(token_dict[""]) - word2index[w] = len(word2token) - word2token.append(w_token) - sentences.append([word2index[w] for w in line_words]) - - logging.info("Constructing ragged tensors") - words = k2.ragged.RaggedTensor(word2token) - sentences = k2.ragged.RaggedTensor(sentences) - - output = dict(words=words, sentences=sentences) - - num_sentences = sentences.dim0 - logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") - sentence_lengths = [0] * num_sentences - for i in range(num_sentences): - if step and i % step == 0: - logging.info( - f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)" - ) - - word_ids = sentences[i] - - # NOTE: If word_ids is a tensor with only 1 entry, - # token_ids is a torch.Tensor - token_ids = words[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - - # token_ids is a 1-D tensor containing the BPE tokens - # of the current sentence - - sentence_lengths[i] = token_ids.numel() - - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) - - torch.save(output, args.lm_archive) - logging.info(f"Saved to {args.lm_archive}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py deleted file mode 100755 index c8cf9b881..000000000 --- a/egs/aishell/ASR/local/prepare_lang.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. 
Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. 
- Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") - return parser.parse_args() - - -def main(): - out_dir = Path(get_args().lang_dir) - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/local/prepare_lang_bbpe.py b/egs/aishell/ASR/local/prepare_lang_bbpe.py deleted file mode 100755 index ddd90622e..000000000 --- a/egs/aishell/ASR/local/prepare_lang_bbpe.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
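# A minimal sketch of what add_disambig_symbols() does, assuming the function
# defined in prepare_lang.py above is importable; the toy lexicon is made up.
# "foo" and "phoo" share the token sequence "f u", and "f" is a prefix of it,
# so disambiguation symbols are appended to keep the lexicon prefix-free.
from prepare_lang import add_disambig_symbols

toy_lexicon = [
    ("foo", ["f", "u"]),
    ("phoo", ["f", "u"]),
    ("f", ["f"]),
]
lexicon_disambig, max_disambig = add_disambig_symbols(toy_lexicon)
# Expected (up to the exact numbering):
#   [("foo", ["f", "u", "#1"]), ("phoo", ["f", "u", "#2"]), ("f", ["f", "#1"])]
# with max_disambig == 2.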
- - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bbpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.byte_utils import byte_encode -from icefall.utils import str2bool, tokenize_by_CJK_char - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str], oov: str -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - oov: - The out of vocabulary word in lexicon. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - # Convert word to word piece IDs instead of word piece strings - # to avoid OOV tokens. - encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words] - words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int) - - # Now convert word piece IDs back to word piece strings. 
- words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--oov", - type=str, - default="", - help="The out of vocabulary word in lexicon.", - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bbpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", args.oov, "#0", "", ""] - - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/local/prepare_lang_fst.py b/egs/aishell/ASR/local/prepare_lang_fst.py deleted file mode 120000 index c5787c534..000000000 --- a/egs/aishell/ASR/local/prepare_lang_fst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py deleted file mode 120000 index 1d6ccbe33..000000000 --- a/egs/aishell/ASR/local/sort_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py deleted file mode 100755 
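The main() above stores the lexicon FSTs with torch.save(L.as_dict(), ...); later stages of the recipe (for example the compile_hlg.py and compile_lg.py calls in prepare.sh) read them back. A minimal sketch of that loading step, assuming k2 is installed; the lang dir path is only an example:

    import k2
    import torch

    lang_dir = "data/lang_bbpe_500"  # example path; any lang dir produced above works
    L_disambig = k2.Fsa.from_dict(
        torch.load(f"{lang_dir}/L_disambig.pt", map_location="cpu")
    )
    # Attach symbol tables so labels print as strings instead of integer IDs.
    L_disambig.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
    L_disambig.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")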
index 74e025ad7..000000000 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py deleted file mode 100755 index 48160897d..000000000 --- a/egs/aishell/ASR/local/train_bbpe_model.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -import tempfile -from pathlib import Path - -import sentencepiece as spm - -from icefall import byte_encode, tokenize_by_CJK_char - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def _convert_to_bchar(in_path: str, out_path: str): - with open(out_path, "w") as f: - for line in open(in_path, "r").readlines(): - f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n") - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - model_file = Path(model_prefix + ".model") - if model_file.is_file(): - print(f"{model_file} exists - skipping") - return - - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["<blk>", "<sos/eos>"] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - temp = tempfile.NamedTemporaryFile() - train_text = temp.name - - _convert_to_bchar(args.transcript, train_text) - - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - - shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh deleted file mode 100755 index 13be69534..000000000 --- a/egs/aishell/ASR/prepare.sh +++ /dev/null @@ -1,391 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=11 -perturb_speed=true - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/aishell -# You can find data_aishell, resource_aishell inside it.
-# You can download them from https://www.openslr.org/33 -# -# - $dl_dir/lm -# This directory contains the language model downloaded from -# https://huggingface.co/pkufool/aishell_lm -# -# - 3-gram.unpruned.arpa -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bbpe_xxx, -# data/lang_bbpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/aishell, - # you can create a symlink - # - # ln -sfv /path/to/aishell $dl_dir/aishell - # - # The directory structure is - # aishell/ - # |-- data_aishell - # | |-- transcript - # | `-- wav - # `-- resource_aishell - # |-- lexicon.txt - # `-- speaker.info - - if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then - lhotse download aishell $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare aishell manifest" - # We assume that you have downloaded the aishell corpus - # to $dl_dir/aishell - if [ ! -f data/manifests/.aishell_manifests.done ]; then - mkdir -p data/manifests - lhotse prepare aishell $dl_dir/aishell data/manifests - touch data/manifests/.aishell_manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for aishell" - if [ ! -f data/fbank/.aishell.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} - touch data/fbank/.aishell.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.musan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -lang_phone_dir=data/lang_phone -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" - mkdir -p $lang_phone_dir - - (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) | - cat - $dl_dir/aishell/resource_aishell/lexicon.txt | - sort | uniq > $lang_phone_dir/lexicon.txt - - ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir - - if [ ! -f $lang_phone_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_phone_dir - fi - - - # Train a bigram P for MMI training - if [ !
-f $lang_phone_dir/transcript_words.txt ]; then - log "Generate data to train phone based bigram P" - aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt - aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid - find $dl_dir/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid - awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | - cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt - fi - - if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_phone_dir/uniq_lexicon.txt \ - --transcript $lang_phone_dir/transcript_words.txt \ - --oov "<UNK>" \ - > $lang_phone_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_phone_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_phone_dir/transcript_tokens.txt \ - -lm $lang_phone_dir/P.arpa - fi - - if [ ! -f $lang_phone_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_phone_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt - fi -fi - -lang_char_dir=data/lang_char -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare char based lang" - mkdir -p $lang_char_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - - # The transcripts in training set, generated in stage 5 - cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt - - cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt | - cut -d " " -f 2- > $lang_char_dir/text - - (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \ - > $lang_char_dir/words.txt - - cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ - | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt - - num_lines=$(< $lang_char_dir/words.txt wc -l) - (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \ - >> $lang_char_dir/words.txt - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py --lang-dir $lang_char_dir - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare Byte BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - mkdir -p $lang_dir - - cp $lang_char_dir/words.txt $lang_dir - cp $lang_char_dir/text $lang_dir - - if [ ! -f $lang_dir/bbpe.model ]; then - ./local/train_bbpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bbpe.py --lang-dir $lang_dir - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - - mkdir -p data/lm - - # Train LM on transcripts - if [ ! -f data/lm/3-gram.unpruned.arpa ]; then - python3 ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lang_char_dir/transcript_words.txt \ - -lm data/lm/3-gram.unpruned.arpa - fi - - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - if [ !
-f data/lm/G_3_gram_char.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="$lang_phone_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt - - python3 -m kaldilm \ - --read-symbol-table="$lang_char_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt - fi - - if [ ! -f $lang_char_dir/HLG.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_char_dir \ - --ngram-G ./data/lm/G_3_gram_char.fst.txt - fi -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile LG & HLG" - ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone - ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char - done - - ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone - ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char - done -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Generate LM training data" - - log "Processing char based data" - out_dir=data/lm_training_char - mkdir -p $out_dir $dl_dir/lm - - if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then - cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt - fi - - # training words - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/aishell-train-word.txt \ - --lm-archive $out_dir/lm_data.pt - - # valid words - if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then - aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt - aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid - find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid - awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text | - cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/aishell-valid-word.txt \ - --lm-archive $out_dir/lm_data_valid.pt - - # test words - if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then - aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt - aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid - find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid - awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text | - cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/aishell-test-word.txt \ - --lm-archive $out_dir/lm_data_test.pt -fi - - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Sort LM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of tokens - # in a sentence. 
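Stage 11, whose commands follow below, orders the LM training sentences by token count with the longest sentences first, as the comment above explains. A toy illustration of that ordering criterion in plain Python (not the actual sort_lm_training_data.py implementation):

    # Each sentence is a list of token IDs; the values here are made up.
    sentences = [[23, 7, 9], [5], [11, 4, 8, 2]]
    order = sorted(range(len(sentences)), key=lambda i: len(sentences[i]), reverse=True)
    assert order == [2, 0, 1]  # longest sentence first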
- - out_dir=data/lm_training_char - mkdir -p $out_dir - ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data_valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data_test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Train RNN LM model" - python ../../../icefall/rnn_lm/train.py \ - --start-epoch 0 \ - --world-size 1 \ - --num-epochs 20 \ - --use-fp16 0 \ - --embedding-dim 512 \ - --hidden-dim 512 \ - --num-layers 2 \ - --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data $out_dir/sorted_lm_data.pt \ - --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ - --vocab-size 4336 \ - --master-port 12345 -fi - -# whisper large-v3 using 128 mel bins, others using 80 mel bins -whisper_mel_bins=80 -output_dir=data/fbank_whisper -if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then - log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning" - if [ ! -f $output_dir/.aishell.whisper.done ]; then - mkdir -p $output_dir - ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir - ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir - touch $output_dir/.aishell.whisper.done - fi -fi diff --git a/egs/aishell/ASR/prepare_aidatatang_200zh.sh b/egs/aishell/ASR/prepare_aidatatang_200zh.sh deleted file mode 100755 index ec89450df..000000000 --- a/egs/aishell/ASR/prepare_aidatatang_200zh.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/aidatatang_200zh -# You can find "corpus" and "transcript" inside it. -# You can download it at -# https://openslr.org/62/ - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - if [ ! -f $dl_dir/aidatatang_200zh/transcript/aidatatang_200_zh_transcript.txt ]; then - lhotse download aidatatang-200zh $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare manifest" - # We assume that you have downloaded the aidatatang_200zh corpus - # to $dl_dir/aidatatang_200zh - if [ ! -f data/manifests/.aidatatang_200zh_manifests.done ]; then - mkdir -p data/manifests - lhotse prepare aidatatang-200zh $dl_dir data/manifests - touch data/manifests/.aidatatang_200zh_manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Process aidatatang_200zh" - if [ ! 
-f data/fbank/.aidatatang_200zh_fbank.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True - touch data/fbank/.aidatatang_200zh_fbank.done - fi -fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 120000 index fa1b8cca3..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/conformer.py b/egs/aishell/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index f41ea6776..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,553 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - 
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - dev_dl = aishell.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index c2dc0d5f3..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding 
multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --jit 0 \ - --epoch 29 \ - --avg 5 - -It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do:: - - cd /path/to/exp_dir - ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default=Path("pruned_transducer_stateless2/exp"), - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = ( - params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" - ) - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/model.py b/egs/aishell/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/optim.py b/egs/aishell/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ 
-../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py deleted file mode 100755 index c4aa98358..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ /dev/null @@ -1,328 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - -(1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lens = [f.size(0) for f in features] - feature_lens = torch.tensor(feature_lens, device=device) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) - - num_waves = encoder_out.size(0) - hyp_list = [] - logging.info(f"Using {params.method}") - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_list = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_list = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_list = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - hyp_list.append(hyp) - - hyps = [] - for hyp in hyp_list: - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/scaling.py 
b/egs/aishell/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py deleted file mode 100755 index 60f014c48..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,1036 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --max-duration 200 \ - -""" - - -import argparse -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse import CutSet -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=2048, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - 
default=512, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
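For orientation, --lr-batches and --lr-epochs above are the two time constants of the Eden scheduler constructed later in run() (its definition lives in optim.py, which is not part of this diff). As far as I recall, the schedule multiplies --initial-lr by two decaying factors of roughly the following form; treat the exact exponents as an assumption and check optim.py before relying on them.

def eden_factor(batch: int, epoch: int,
                lr_batches: float = 5000.0, lr_epochs: float = 6.0) -> float:
    # Assumed form of the Eden schedule: each factor is close to 1 early on
    # and decays roughly like t ** -0.5 once t exceeds its time constant.
    batch_factor = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    epoch_factor = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    return batch_factor * epoch_factor

# effective lr ~= initial_lr * eden_factor(batch_idx_train, cur_epoch)
print(eden_factor(0, 0))       # 1.0 at the start of training
print(eden_factor(20000, 12))  # noticeably smaller later in training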
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
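The interplay of --save-every-n and --keep-last-k above amounts to a rolling window over the periodically saved checkpoint-<batch>.pt files, while epoch-<n>.pt files are never deleted. A simplified sketch of that pruning behaviour (the real logic is remove_checkpoints from icefall.checkpoint, which may differ in details):

from pathlib import Path

def prune_batch_checkpoints(exp_dir: Path, keep_last_k: int) -> None:
    # Sort checkpoint-<global_batch_idx>.pt files by batch index, newest first.
    ckpts = sorted(
        exp_dir.glob("checkpoint-*.pt"),
        key=lambda p: int(p.stem.split("-")[1]),
        reverse=True,
    )
    # Delete everything beyond the newest k; epoch-*.pt is not matched above.
    for old in ckpts[keep_last_k:]:
        old.unlink()

# prune_batch_checkpoints(Path("pruned_transducer_stateless2/exp"), keep_last_k=30)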
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
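In compute_loss above, the transcripts are mapped to token IDs and wrapped in a k2.RaggedTensor so that utterances of different lengths form one ragged batch without padding. A tiny illustration with made-up character IDs (the real IDs come from CharCtcTrainingGraphCompiler.texts_to_ids):

import k2

# Hypothetical token IDs for two utterances of different lengths.
y = k2.RaggedTensor([[12, 7, 7, 30], [5, 19]])

# Per-utterance label counts, recovered from the ragged shape.
row_splits = y.shape.row_splits(1)   # tensor([0, 4, 6])
y_lens = row_splits[1:] - row_splits[:-1]
print(y_lens)                        # tensor([4, 2])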
- pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - rng: random.Random, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
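The loss combination above is effectively a warm-up schedule on the pruned RNN-T loss: warmup is batch_idx_train / model_warm_step, the pruned term is switched off until warmup reaches 1.0, held at a weight of 0.1 for roughly the same number of steps, and only then weighted fully, while the simple loss is always scaled by --simple-loss-scale. Restated as a stand-alone helper, transcribed from the expression above:

def combine_losses(simple_loss: float, pruned_loss: float,
                   warmup: float, simple_loss_scale: float = 0.5) -> float:
    # warmup = batch_idx_train / model_warm_step (see train_one_epoch).
    pruned_loss_scale = (
        0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)
    )
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

print(combine_losses(2.0, 1.5, warmup=0.5))  # 1.0   (pruned loss disabled)
print(combine_losses(2.0, 1.5, warmup=1.5))  # 1.15  (pruned loss kept small)
print(combine_losses(2.0, 1.5, warmup=3.0))  # 2.5   (fully warmed up)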
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], " - f"batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train, - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 12.0 - - return cuts - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
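One apparent issue in the block above: filter_short_and_long_utterances defines the remove_short_and_long_utt predicate but, as printed here, returns cuts without ever applying it, so no utterances would actually be filtered. Presumably the intended body applies the predicate through CutSet.filter; a reconstructed sketch (an editorial guess, not the committed code):

from lhotse import CutSet
from lhotse.cut import Cut

def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances between 1 and 12 seconds; see
        # ../local/display_manifest_statistics.py for the duration distribution.
        return 1.0 <= c.duration <= 12.0

    # The step that appears to be missing in the version shown above:
    cuts = cuts.filter(remove_short_and_long_utt)
    return cuts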
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - rng = random.Random(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - rng=rng, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - 
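For readers not used to the DDP boilerplate in run() above: each spawned process binds to its own GPU by rank, the process group is initialised once per process (that is what setup_dist does), and only then is the model wrapped in DistributedDataParallel. A pared-down sketch of that pattern; the init_method details here are assumptions, since setup_dist itself is not part of this diff.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module, rank: int, world_size: int,
                 master_port: int = 12354) -> torch.nn.Module:
    # Roughly what happens before/around the DDP(...) call in run().
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://127.0.0.1:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    device = torch.device("cuda", rank)
    model = model.to(device)
    # find_unused_parameters=True matches the call in run(); it is needed when
    # some submodules do not take part in every backward pass.
    return DDP(model, device_ids=[rank], find_unused_parameters=True)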
graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py b/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py deleted file mode 120000 index 9a799406b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified-2/aidatatang_200zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py b/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py deleted file mode 120000 index 1b5f38a54..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified-2/aishell.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py deleted file mode 120000 index ae3bdd1e0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified-2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py b/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py deleted file mode 120000 index c7c1a4b6e..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py deleted file mode 100755 index 3901a330c..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ /dev/null @@ -1,795 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
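One easy-to-miss detail of the mp.spawn launch in main() of train.py above: the rank argument of run() is supplied implicitly by mp.spawn as the first positional argument, so args= only needs to carry (world_size, args). A self-contained toy example (names are illustrative):

import torch.multiprocessing as mp

def worker(rank: int, world_size: int, tag: str) -> None:
    # mp.spawn passes the process index as the first positional argument;
    # only the remaining arguments come from `args=`.
    print(f"[{tag}] process {rank} of {world_size}")

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size, "demo"), nprocs=world_size, join=True)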
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -(5) modified beam search (with LM shallow fusion) -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search_lm_shallow_fusion \ - --beam-size 4 \ - --lm-type rnn \ - --lm-scale 0.3 \ - --lm-exp-dir /path/to/LM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 - -(6) modified beam search with LM shallow fusion + LODR -./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --max-duration 600 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --decoding-method modified_beam_search_LODR \ - --beam-size 4 \ - --lm-type rnn \ - --lm-scale 0.48 \ - --lm-exp-dir /path/to/LM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 - --tokens-ngram 2 \ - --ngram-lm-scale -0.28 \ -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from aishell import AIShell -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""Token Ngram used for rescoring. - Used only when the decoding method is - modified_beam_search_ngram_rescoring""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="""ID of the backoff symbol. 
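For context on --avg above: plain checkpoint averaging is just an element-wise mean over the saved model state_dicts of the selected epochs. A simplified sketch of that core computation (the real average_checkpoints / average_checkpoints_with_averaged_model helpers in icefall.checkpoint handle more bookkeeping and may differ in details):

from typing import List
import torch

def average_state_dicts(filenames: List[str],
                        device: torch.device = torch.device("cpu")) -> dict:
    """Element-wise mean of the 'model' state_dicts in the given checkpoints."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        sd = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] = avg[k] + sd[k]
    for k in avg:
        # Integer buffers, if present, would need special handling in practice.
        avg[k] = avg[k] / n
    return avg

# model.load_state_dict(
#     average_state_dicts([f"exp/epoch-{i}.pt" for i in range(16, 31)]),
#     strict=False,
# )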
- Used only when the decoding method is - modified_beam_search_ngram_rescoring""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: 
{params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - params.datatang_prob = 0 - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - if "ngram" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if params.use_shallow_fusion: - if params.lm_type == "rnn": - params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" - elif params.lm_type == "transformer": - params.suffix += f"-transformer-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - 
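The CER scoring above hinges on one small transformation: the word-level reference and hypothesis are joined and re-split into individual characters before write_error_stats is called with compute_CER=True, so the reported metric is a character error rate. That conversion in isolation:

def to_char_level(cut_id: str, ref_words: list, hyp_words: list) -> tuple:
    # Join the word sequences and re-split into single characters so that
    # the downstream error statistics become a CER rather than a WER.
    return cut_id, list("".join(ref_words)), list("".join(hyp_words))

print(to_char_level("utt-0001", ["大家", "好"], ["大象", "好"]))
# ('utt-0001', ['大', '家', '好'], ['大', '象', '好'])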
model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - # only load N-gram LM when needed - if "ngram" in params.decoding_method or "LODR" in params.decoding_method: - lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - lm_filename, - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - # only load the neural network LM if doing shallow fusion - if params.use_shallow_fusion: - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - - else: - LM = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - asr_datamodule = AsrDataModule(args) - aishell = AIShell(manifest_dir=args.manifest_dir) - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - test_dl = asr_datamodule.test_dataloaders(test_cuts) - dev_dl = asr_datamodule.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py deleted file mode 100755 index 2248c7a08..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
-""" -Usage: -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --jit 0 \ - --epoch 29 \ - --avg 5 - -It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt - -To use the generated file with `pruned_transducer_stateless3/decode.py`, -you can do:: - - cd /path/to/exp_dir - ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./pruned_transducer_stateless3/decode.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default=Path("pruned_transducer_stateless3/exp"), - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - params.datatang_prob = 0 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = ( - params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" - ) - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py deleted file mode 120000 index 557e18aa1..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py deleted file mode 100644 index a4dda0d6d..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - decoder_datatang: Optional[nn.Module] = None, - joiner_datatang: Optional[nn.Module] = None, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. 
- joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size). - Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - encoder_dim: - Output dimension of the encoder network. - decoder_dim: - Output dimension of the decoder network. - joiner_dim: - Input dimension of the joiner network. - vocab_size: - Output dimension of the joiner network. - decoder_datatang: - Optional. The decoder network for the aidatatang_200zh dataset. - joiner_datatang: - Optional. The joiner network for the aidatatang_200zh dataset. - """ - super().__init__() - - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.decoder_datatang = decoder_datatang - self.joiner_datatang = joiner_datatang - - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - if decoder_datatang is not None: - self.simple_am_proj_datatang = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) - self.simple_lm_proj_datatang = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - aishell: bool = True, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - aishell: - True to use the decoder and joiner for the aishell dataset. - False to use the decoder and joiner for the aidatatang_200zh - dataset. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(encoder_out_lens > 0) - - if aishell: - decoder = self.decoder - simple_lm_proj = self.simple_lm_proj - simple_am_proj = self.simple_am_proj - joiner = self.joiner - else: - decoder = self.decoder_datatang - simple_lm_proj = self.simple_lm_proj_datatang - simple_am_proj = self.simple_am_proj_datatang - joiner = self.joiner_datatang - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. 
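The comment above marks where the decoder inputs are assembled. A tiny self-contained illustration of add_sos plus ragged padding, assuming k2 and icefall are importable; the token IDs are made up:

import k2
from icefall.utils import add_sos

blank_id = 0
y = k2.RaggedTensor([[5, 7, 9], [3, 4]])  # two utterances, made-up label IDs

sos_y = add_sos(y, sos_id=blank_id)       # prepend blank (used as SOS) to each row
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
y_padded = y.pad(mode="constant", padding_value=0)

print(sos_y_padded)  # [[0, 5, 7, 9], [0, 3, 4, 0]]  -> shape [B, S + 1]
print(y_padded)      # [[5, 7, 9],    [3, 4, 0]]     -> shape [B, S]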
- sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = encoder_out_lens - - lm = simple_lm_proj(decoder_out) - am = simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=joiner.encoder_proj(encoder_out), - lm=joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return (simple_loss, pruned_loss) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/optim.py b/egs/aishell/ASR/pruned_transducer_stateless3/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py deleted file mode 100755 index 69fe3a40b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ /dev/null @@ -1,329 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
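Before moving on to the deleted pretrained.py below: the forward() of model.py above chains four k2 calls (smoothed simple loss, prune-range computation, pruning, pruned loss). A rough, untested sketch with random tensors and placeholder dimensions, meant only to make the tensor shapes concrete, not to replace the model code:

import torch
import k2

B, T, S, vocab, joiner_dim, prune_range = 2, 20, 5, 10, 16, 5
blank_id = 0

am = torch.randn(B, T, vocab)        # simple AM projection, [B, T, vocab]
lm = torch.randn(B, S + 1, vocab)    # simple LM projection, [B, S + 1, vocab]
symbols = torch.randint(1, vocab, (B, S), dtype=torch.int64)

boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S                   # label lengths
boundary[:, 3] = T                   # frame lengths

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm, am=am, symbols=symbols, termination_symbol=blank_id,
    lm_only_scale=0.25, am_only_scale=0.0,
    boundary=boundary, reduction="sum", return_grad=True,
)

# Which (frame, symbol) pairs survive pruning; shape [B, T, prune_range].
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range,
)

am_proj = torch.randn(B, T, joiner_dim)      # stands in for joiner.encoder_proj(encoder_out)
lm_proj = torch.randn(B, S + 1, joiner_dim)  # stands in for joiner.decoder_proj(decoder_out)
am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am_proj, lm=lm_proj, ranges=ranges)

# Stand-in for the joiner: add and project to the vocabulary.
logits = torch.nn.Linear(joiner_dim, vocab)(torch.tanh(am_pruned + lm_pruned))

pruned_loss = k2.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=blank_id, boundary=boundary, reduction="sum",
)
print(simple_loss.item(), pruned_loss.item())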
- -""" -Usage: - -(1) greedy search -./pruned_transducer_stateless3/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless3/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless3/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless3/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="Maximum number of symbols per frame. 
" - "Use only when --method is greedy_search", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - params.datatang_prob = 0 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lens = [f.size(0) for f in features] - feature_lens = torch.tensor(feature_lens, device=device) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) - - num_waves = encoder_out.size(0) - hyp_list = [] - logging.info(f"Using {params.method}") - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_list = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_list = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_list = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, 
- ) - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - hyp_list.append(hyp) - - hyps = [] - for hyp in hyp_list: - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py deleted file mode 100755 index 7c23041ca..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ /dev/null @@ -1,1251 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -./prepare.sh - -# If you use a non-zero value for --datatang-prob, you also need to run -./prepare_aidatatang_200zh.sh - -If you use --datatang-prob=0, then you don't need to run the above script. 
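Stepping back to pretrained.py above: features are computed on the fly with kaldifeat, using the same frame options as in training. A minimal sketch of that feature pipeline for a single 16 kHz wav (the file name is hypothetical):

import math

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

wave, sample_rate = torchaudio.load("foo.wav")   # hypothetical input file
assert sample_rate == 16000, sample_rate
waves = [wave[0]]                                # first channel only

features = fbank(waves)                          # list of [num_frames, 80] tensors
feature_lens = torch.tensor([f.size(0) for f in features])
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape, feature_lens)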
- -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - -./pruned_transducer_stateless3/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 0 \ - --exp-dir pruned_transducer_stateless3/exp \ - --max-duration 300 \ - --datatang-prob 0.2 - -# For mix precision training: - -./pruned_transducer_stateless3/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless3/exp \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aidatatang_200zh import AIDatatang200zh -from aishell import AIShell -from asr_datamodule import AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=2048, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--datatang-prob", - type=float, - default=0.0, - help="""The probability to select a batch from the - aidatatang_200zh dataset. - If it is set to 0, you don't need to download the data - for aidatatang_200zh. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - if params.datatang_prob > 0: - decoder_datatang = get_decoder_model(params) - joiner_datatang = get_joiner_model(params) - else: - decoder_datatang = None - joiner_datatang = None - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - decoder_datatang=decoder_datatang, - joiner_datatang=joiner_datatang, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
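get_params() above returns an icefall AttributeDict that later absorbs the parsed command-line options via params.update(vars(args)), so every value is reachable both as a key and as an attribute. A tiny sketch of that access pattern:

from icefall.utils import AttributeDict

params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
params.update({"num_epochs": 30})   # e.g. merged from vars(args)

print(params.feature_dim)           # attribute-style access -> 80
print(params["num_epochs"])         # dict-style access      -> 30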
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def is_aishell(c: Cut) -> bool: - """Return True if this cut is from the AIShell dataset. - - Note: - During data preparation, we set the custom field in - the supervision segment of aidatatang_200zh to - dict(origin='aidatatang_200zh') - See ../local/process_aidatatang_200zh.py. - """ - return c.supervisions[0].custom is None - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
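The warmup value described above gates how much the pruned loss contributes. compute_loss (a few lines below) implements the schedule as a nested conditional; written out as a small helper for readability, with the thresholds taken from that code:

def pruned_loss_scale(warmup: float) -> float:
    # Off before warm-up completes, small for a while afterwards (so the
    # pruned loss does not overwhelm the simple loss), then full strength.
    if warmup < 1.0:
        return 0.0
    if 1.0 < warmup < 2.0:
        return 0.1
    return 1.0


for w in (0.5, 1.5, 3.0):
    print(w, pruned_loss_scale(w))  # 0.0, 0.1, 1.0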
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - aishell = is_aishell(supervisions["cut"][0]) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - aishell=aishell, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - datatang_train_dl: Optional[torch.utils.data.DataLoader], - valid_dl: torch.utils.data.DataLoader, - rng: random.Random, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. 
- model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - aishell_tot_loss = MetricsTracker() - datatang_tot_loss = MetricsTracker() - tot_loss = MetricsTracker() - - # index 0: for LibriSpeech - # index 1: for GigaSpeech - # This sets the probabilities for choosing which datasets - dl_weights = [1 - params.datatang_prob, params.datatang_prob] - - iter_aishell = iter(train_dl) - if datatang_train_dl is not None: - iter_datatang = iter(datatang_train_dl) - - batch_idx = 0 - - while True: - if datatang_train_dl is not None: - idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] - dl = iter_aishell if idx == 0 else iter_datatang - else: - dl = iter_aishell - - try: - batch = next(dl) - except StopIteration: - break - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - aishell = is_aishell(batch["supervisions"]["cut"][0]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - if datatang_train_dl is not None: - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - if aishell: - aishell_tot_loss = ( - aishell_tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - prefix = "aishell" # for logging only - else: - datatang_tot_loss = ( - datatang_tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - prefix = "datatang" - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
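The dl_weights / rng.choices logic above picks, for every batch, whether to pull from the aishell or the aidatatang_200zh dataloader (the "LibriSpeech / GigaSpeech" comment appears to be a leftover from the librispeech recipe). A quick empirical check of the sampling ratio, runnable on its own:

import random

rng = random.Random(42)
datatang_prob = 0.2
dl_weights = [1 - datatang_prob, datatang_prob]

counts = [0, 0]
for _ in range(10_000):
    idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
    counts[idx] += 1
print(counts)  # roughly [8000, 2000]: ~80% aishell batches, ~20% aidatatang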
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - if datatang_train_dl is not None: - datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " - else: - tot_loss_str = "" - datatang_str = "" - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, {prefix}_loss[{loss_info}], " - f"{tot_loss_str}" - f"aishell_tot_loss[{aishell_tot_loss}], " - f"{datatang_str}" - f"batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, - f"train/current_{prefix}_", - params.batch_idx_train, - ) - if datatang_train_dl is not None: - # If it is None, tot_loss is the same as aishell_tot_loss. - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - aishell_tot_loss.write_summary( - tb_writer, "train/aishell_tot_", params.batch_idx_train - ) - if datatang_train_dl is not None: - datatang_tot_loss.write_summary( - tb_writer, "train/datatang_tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if datatang_train_dl is not None: - loss_value = tot_loss["loss"] / tot_loss["frames"] - else: - loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 12.0 - - return cuts - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - rng = random.Random(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - if params.datatang_prob > 0: - find_unused_parameters = True - else: - find_unused_parameters = False - - model = DDP( - model, - device_ids=[rank], - find_unused_parameters=find_unused_parameters, - ) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell = AIShell(manifest_dir=args.manifest_dir) - train_cuts = aishell.train_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) - - if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") - else: - cuts_musan = None - - asr_datamodule = AsrDataModule(args) - - train_dl = asr_datamodule.train_dataloaders( - train_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - ) - - if params.datatang_prob > 0: - datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) - train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) - 
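Note that filter_short_and_long_utterances above defines the duration predicate but returns cuts without applying it; presumably a cuts.filter(...) call was intended. A minimal corrected sketch using lhotse's CutSet.filter:

from lhotse import CutSet


def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
    # Keep only utterances whose duration lies between 1 and 12 seconds.
    def remove_short_and_long_utt(c) -> bool:
        return 1.0 <= c.duration <= 12.0

    return cuts.filter(remove_short_and_long_utt)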
train_datatang_cuts = train_datatang_cuts.repeat(times=None) - datatang_train_dl = asr_datamodule.train_dataloaders( - train_datatang_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - ) - else: - datatang_train_dl = None - logging.info("Not using aidatatang_200zh for training") - - valid_cuts = aishell.valid_cuts() - valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - - for dl in [ - train_dl, - # datatang_train_dl - ]: - if dl is not None: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - if datatang_train_dl is not None: - datatang_train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - datatang_train_dl=datatang_train_dl, - valid_dl=valid_dl, - rng=rng, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. 
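The comment above explains why the OOM scan below uses warmup=0.0. For reference, here is the scaler-driven update used in train_one_epoch further above (scale -> backward -> step -> update -> zero_grad) in a self-contained toy form; the model and data are placeholders and the scaler is simply disabled when no GPU is present:

import torch
from torch.cuda.amp import GradScaler

use_fp16 = torch.cuda.is_available()
device = torch.device("cuda" if use_fp16 else "cpu")

model = torch.nn.Linear(10, 1).to(device)   # toy stand-in for the transducer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_fp16)

x = torch.randn(8, 10, device=device)
y = torch.randn(8, 1, device=device)

with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
print(loss.item())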
- with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - assert 0 <= args.datatang_prob < 1, args.datatang_prob - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/aishell.py b/egs/aishell/ASR/pruned_transducer_stateless7/aishell.py deleted file mode 120000 index ce581b950..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/aishell.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless3/aishell.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7/asr_datamodule.py deleted file mode 120000 index ae3bdd1e0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified-2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7/beam_search.py deleted file mode 120000 index e9bbcf2a9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless3/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py deleted file mode 100755 index d50bccf82..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,689 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
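Before the deleted decode.py below: main() in train.py above launches one process per GPU with torch.multiprocessing when world_size > 1. The launch pattern in isolation, with a trivial worker standing in for run():

import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # In train.py this is where setup_dist(), DDP wrapping and training happen.
    print(f"started worker {rank} of {world_size}")


if __name__ == "__main__":
    world_size = 2
    if world_size > 1:
        # mp.spawn passes the rank as the first argument automatically.
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1)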
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from aishell import AIShell -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding_method is modified_beam_search. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding_method is modified_beam_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - key = f"beam_size_{params.beam_size}" - if params.has_contexts: - key += f"-context-score-{params.context_score}" - else: - key += "-no-context-words" - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - - store_transcripts(filename=recog_path, texts=results_char, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if 
params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += "-no-contexts-words" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - if params.decoding_method == "modified_beam_search": - if os.path.exists(params.context_file): - contexts_text = [] - 
for line in open(params.context_file).readlines(): - contexts_text.append(line.strip()) - contexts = graph_compiler.texts_to_ids(contexts_text) - context_graph = ContextGraph(params.context_score) - context_graph.build([(c, 0.0) for c in contexts]) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - asr_datamodule = AsrDataModule(args) - aishell = AIShell(manifest_dir=args.manifest_dir) - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - test_dl = asr_datamodule.test_dataloaders(test_cuts) - dev_dl = asr_datamodule.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decoder2.py b/egs/aishell/ASR/pruned_transducer_stateless7/decoder2.py deleted file mode 100644 index 0345db0a8..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decoder2.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. 
- """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - ) - self.blank_id = blank_id - - assert context_size == 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - if torch.jit.is_tracing(): - # This is for exporting to PNNX via ONNX - embedding_out = self.embedding(y) - else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py deleted file mode 100755 index 058d0ff6b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py +++ /dev/null @@ -1,1255 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -./prepare.sh - -If you use --datatang-prob=0, then you don't need to run the above script. 
- -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aishell import AIShell -from asr_datamodule import AsrDataModule -from decoder2 import Decoder -from joiner import Joiner -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
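# Illustrative sketch (hypothetical values) of what load_checkpoint_if_available()
# above resolves to:
#
#   --start-batch 8000                -> loads exp-dir/checkpoint-8000.pt
#   --start-batch 0 --start-epoch 3   -> loads exp-dir/epoch-2.pt
#   --start-batch 0 --start-epoch 1   -> returns None (training from scratch)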
- """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
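# Illustrative note on the block below (hypothetical values): a grad scale of
# 0.25 is doubled on every 100th batch, a scale of 4.0 is only doubled on every
# 400th batch, and once the scale reaches 8.0 it is left to GradScaler's own
# growth logic. A scale below 0.01 triggers a warning, and below 1e-5 training is
# aborted via raise_grad_scale_is_too_small_error().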
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error() - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 12.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - # T = ((c.num_frames - 7) // 2 + 1) // 2 - # tokens = sp.encode(c.supervisions[0].text, out_type=str) - - # if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. 
" - # f"Number of tokens: {len(tokens)}" - # ) - # return False - - return True - - aishell = AIShell(manifest_dir=args.manifest_dir) - train_cuts = aishell.train_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") - else: - cuts_musan = None - - asr_datamodule = AsrDataModule(args) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = asr_datamodule.train_dataloaders( - train_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - sampler_state_dict=sampler_state_dict, - ) - - valid_cuts = aishell.valid_cuts() - valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # graph_compiler=graph_compiler, - # params=params, - # ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7/encoder_interface.py deleted file mode 120000 index 0c2673d46..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py deleted file mode 100755 index 4981fb71a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ /dev/null @@ -1,587 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang -# Xiaoyu Yang) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless7/export-onnx.py \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --feedforward-dims "1024,1024,2048,2048,1024" - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. 
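# Illustrative sketch (not part of this script) of how the exported encoder could
# be loaded with onnxruntime, assuming the file name produced by the commands
# above and the "x"/"x_lens" input names and "encoder_out"/"encoder_out_lens"
# output names defined by export_encoder_model_onnx() later in this file:
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
#   x = np.zeros((1, 100, 80), dtype=np.float32)   # (N, T, C) fbank features
#   x_lens = np.array([100], dtype=np.int64)       # (N,)
#   encoder_out, encoder_out_lens = sess.run(
#       ["encoder_out", "encoder_out_lens"], {"x": x, "x_lens": x_lens}
#   )
#
# For the decoder/joiner files, sess.get_modelmeta().custom_metadata_map exposes
# the key/value pairs written by add_meta_data() (e.g. context_size, vocab_size).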
-""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder2 import Decoder -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from zipformer import Zipformer - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). 
Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - encoder_out, encoder_out_lens = self.encoder(x, x_lens) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. 
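The dynamic_axes mapping is what allows the exported encoder to accept an arbitrary batch size N and frame count T at runtime. Below is a self-contained toy illustration of the same torch.onnx.export call; Toy and toy-encoder.onnx are made up for this sketch and are not part of the recipe.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        # x: (N, T, C); x_lens: (N,)
        return x * 2, x_lens

torch.onnx.export(
    Toy(),
    (torch.zeros(1, 100, 80), torch.tensor([100], dtype=torch.int64)),
    "toy-encoder.onnx",
    opset_version=13,
    input_names=["x", "x_lens"],
    output_names=["out", "out_lens"],
    dynamic_axes={
        "x": {0: "N", 1: "T"},
        "x_lens": {0: "N"},
        "out": {0: "N", 1: "T"},
        "out_lens": {0: "N"},
    },
)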
- """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg 
{params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = 
params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 120000 index 2713792e6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py deleted file mode 100755 index 5143f2cae..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir ./data/lang_char \ - --epoch 20 \ - --avg 10 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless7/jit_pretrained.py \ - --nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \ - --lang-dir ./data/lang_char \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. 
- Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - lexicon = Lexicon(args.lang_dir) - token_table = lexicon.token_table - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - 
filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - hyps = [[token_table[t] for t in tokens] for tokens in hyps] - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/model.py b/egs/aishell/ASR/pruned_transducer_stateless7/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_check.py deleted file mode 120000 index e97d1c0aa..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py deleted file mode 100755 index 8e8e971eb..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ /dev/null @@ -1,423 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 -as an example to show how to use this file. - -1. 
Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" - -cd exp -ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx - -3. Run this file - -./pruned_transducer_stateless3/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. 
- Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - 
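For readers who find the packed-sequence bookkeeping above hard to follow, here is a simplified single-utterance variant of the same greedy search, written against the OnnxModel wrapper defined earlier. It is a sketch for illustration, not a drop-in replacement for the batched version:

from typing import List
import torch

def greedy_search_single(model: "OnnxModel", encoder_out: torch.Tensor) -> List[int]:
    """encoder_out: (T', joiner_dim) for one utterance, i.e. one row of run_encoder()."""
    blank_id = 0  # hard-coded, as in the batched version
    context_size = model.context_size
    hyp = [blank_id] * context_size
    decoder_out = model.run_decoder(
        torch.tensor([hyp], dtype=torch.int64)
    )  # (1, joiner_dim)
    for t in range(encoder_out.size(0)):
        # At most one symbol per frame (--max-sym-per-frame=1).
        logits = model.run_joiner(encoder_out[t : t + 1], decoder_out)  # (1, vocab_size)
        y = logits.argmax(dim=1).item()
        if y != blank_id:
            hyp.append(y)
            # The stateless decoder only sees the last context_size tokens,
            # so re-run it on the updated context.
            decoder_out = model.run_decoder(
                torch.tensor([hyp[-context_size:]], dtype=torch.int64)
            )
    return hyp[context_size:]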
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += symbol_table[i] - return text.replace("▁", " ").strip() - - context_size = model.context_size - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp[context_size:]) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 120000 index 068f0f57f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py deleted file mode 100755 index 2dc835f3b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1254 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -./prepare.sh - -If you use --datatang-prob=0, then you don't need to run the above script. 
- -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aishell import AIShell -from asr_datamodule import AsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. 
" - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
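All of the Zipformer stack hyperparameters above are passed as comma-separated strings, one entry per encoder stack, and the joiner consumes the last stack's embedding dimension. A tiny sketch of that convention, using the default values:

def to_int_tuple(s: str):
    # Same helper as in get_encoder_model() above.
    return tuple(map(int, s.split(",")))

encoder_dims = "384,384,384,384,384"   # --encoder-dims default
layers = "2,4,3,2,4"                   # --num-encoder-layers default

print(to_int_tuple(layers))                 # (2, 4, 3, 2, 4)
# The joiner/transducer use the last stack's embedding dimension:
print(int(encoder_dims.split(",")[-1]))     # 384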
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
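For intuition about the model_avg that load_checkpoint_if_available() and save_checkpoint() carry around: the update rule quoted in the --average-period help reduces, per parameter, to a running arithmetic mean of the snapshots taken at each update. A scalar illustration with made-up numbers:

# model_avg = model * (p / n) + model_avg * ((n - p) / n),
# where p = average_period and n = batch_idx_train at update time.
average_period = 200  # the --average-period default

model_avg = 0.0  # stands in for one parameter of model_avg
for step in range(1, 5):
    n = step * average_period       # batch_idx_train at this update
    model = float(step)             # stands in for the current parameter value
    model_avg = model * (average_period / n) + model_avg * ((n - average_period) / n)
    print(f"after {n} batches: model_avg = {model_avg:.3f}")
# After 800 batches, model_avg == (1 + 2 + 3 + 4) / 4 == 2.5,
# i.e. the mean of the values seen at each update.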
- """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
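The try block above follows the standard torch.cuda.amp pattern: run the forward pass under autocast, scale the loss before backward, then let the GradScaler unscale and step the optimizer. A self-contained toy sketch of that pattern (the model and data are placeholders, not the recipe's):

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"

model = nn.Linear(80, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler(enabled=use_fp16, init_scale=1.0)

x = torch.randn(8, 80, device=device)
target = torch.randn(8, 10, device=device)

with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)          # unscales gradients, then calls optimizer.step()
scaler.update()                 # adjusts the scale for the next iteration
optimizer.zero_grad()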
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 12.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - # T = ((c.num_frames - 7) // 2 + 1) // 2 - # tokens = sp.encode(c.supervisions[0].text, out_type=str) - - # if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. 
" - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - # return False - - return True - - aishell = AIShell(manifest_dir=args.manifest_dir) - train_cuts = aishell.train_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") - else: - cuts_musan = None - - asr_datamodule = AsrDataModule(args) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = asr_datamodule.train_dataloaders( - train_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - sampler_state_dict=sampler_state_dict, - ) - - valid_cuts = aishell.valid_cuts() - valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # graph_compiler=graph_compiler, - # params=params, - # ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py deleted file mode 120000 index 8554e44cc..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py deleted file mode 100755 index 46f542641..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py +++ /dev/null 
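The main() of this train.py launches one process per GPU via mp.spawn, with rank 0 responsible for logging and checkpoints and the recipe's setup_dist/cleanup_dist helpers handling the process group. A stripped-down sketch of that launch pattern is below; the address and port are placeholders (the recipe takes the port from --master-port) and the training body is elided.

import os
import torch
import torch.multiprocessing as mp
from torch import distributed as dist

def run(rank: int, world_size: int):
    # One process per GPU; rank 0 conventionally writes logs and checkpoints.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12354")  # placeholder port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, run the epochs ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    assert world_size >= 1, "this sketch assumes at least one CUDA device"
    if world_size > 1:
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1)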
@@ -1,822 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import ( - LmScorer, - NgramLm, - byte_encode, - smart_byte_decode, - tokenize_by_CJK_char, -) -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from 
icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the byte BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bbpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.25, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. 
- """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - ref_texts = [] - for tx in supervisions["text"]: - ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) - - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(ref_texts), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - 
key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - key += f"_ilme_scale_{params.ilme_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
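For AISHELL the scoring in save_results is character based: each (ref, hyp) pair of word lists is joined and re-split into characters before write_error_stats is called with compute_CER=True, so the reported figure is effectively a CER. Conceptually each pair is scored by an edit distance over characters normalised by the reference length; a small illustrative implementation of just that metric follows (write_error_stats additionally produces alignments and per-symbol statistics).

def character_error_rate(ref: str, hyp: str) -> float:
    # Levenshtein distance over characters, normalised by the reference length.
    r, h = list(ref), list(hyp)
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[len(r)][len(h)] / max(len(r), 1)

# e.g. character_error_rate("交易停滞", "交易停止") == 0.25  (one substitution out of four characters)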
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - params.suffix += f"-ilme-scale-{params.ilme_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
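Two averaging modes are handled just above: with --use-averaged-model the decoder interpolates the running averaged model stored in two checkpoints, while the plain path loads the last --avg epoch checkpoints and averages their parameters directly. A simplified sketch of the latter, assuming each file stores the weights under a "model" key as this recipe's checkpoints do:

import torch

def average_state_dicts(filenames, device="cpu"):
    # Plain parameter averaging over N checkpoints.
    n = len(filenames)
    avg = None
    for f in filenames:
        sd = torch.load(f, map_location=device)["model"]
        if avg is None:
            avg = {
                k: v.to(torch.float64) if v.is_floating_point() else v.clone()
                for k, v in sd.items()
            }
        else:
            for k, v in sd.items():
                if avg[k].is_floating_point():
                    avg[k] += v.to(torch.float64)
    return {
        k: (v / n).to(torch.float32) if v.is_floating_point() else v
        for k, v in avg.items()
    }

Accumulating in float64 limits rounding error when many checkpoints are combined, which is also why the training script keeps model_avg in float64.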
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - - test_dl = aishell.test_dataloaders(test_cuts) - dev_dl = aishell.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py deleted file mode 120000 index b9aa0ae08..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py deleted file mode 100755 index 4e82b45d3..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py +++ /dev/null @@ -1,320 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. 
- -To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./pruned_transducer_stateless7_bbpe/decode.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bbpe_500/bbpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe - # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_bbpe/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
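The comment above describes the workaround applied right below: the transducer's forward() takes a k2 ragged tensor and cannot be compiled, so it is replaced with a torch.jit.ignore'd version before scripting, leaving the submodules that the deployment code actually calls to be compiled. A toy illustration of the same pattern, with a placeholder module:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        # Stand-in for a method whose signature TorchScript cannot handle.
        return x

    @torch.jit.export
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

Toy.forward = torch.jit.ignore(Toy.forward)  # leave forward() as plain Python
m = torch.jit.script(Toy())                  # only run() is compiled
print(m.run(torch.ones(2)))                  # tensor([2., 2.])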
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py deleted file mode 100755 index 8fb7ac278..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ /dev/null @@ -1,275 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 49 \ - --avg 28 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless7_bbpe/jit_pretrained.py \ - --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall import smart_byte_decode - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. 
- Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - 
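read_sound_files above simply asserts that every input file is already sampled at 16 kHz. When that assumption does not hold, a resampling step along these lines (an illustrative helper, not part of this script) can be applied before feature extraction:

import torch
import torchaudio

def load_and_resample(path: str, target_sr: int = 16000) -> torch.Tensor:
    wave, sr = torchaudio.load(path)
    if sr != target_sr:
        wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=target_sr)
    return wave[0]  # keep only the first channel, matching read_sound_files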
waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = smart_byte_decode(sp.decode(hyp)) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py deleted file mode 100755 index 12004315b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ /dev/null @@ -1,346 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. 
-You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 48 \ - --avg 29 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_bbpe/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import smart_byte_decode -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - 
encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py deleted file mode 120000 index 7ceac5d10..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py deleted file mode 100755 index 811269989..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ /dev/null @@ -1,1249 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --max-duration 400 - -# For mix precision training: - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --max-duration 800 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut, CutSet -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import byte_encode, diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, - tokenize_by_CJK_char, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension 
in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_bbpe/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the Byte BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - train_cuts = aishell.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 12.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - def tokenize_and_encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = byte_encode(tokenize_by_CJK_char(text)) - c.supervisions[0].text = text - return c - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_cuts = train_cuts.map(tokenize_and_encode_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = aishell.valid_cuts() - valid_cuts = valid_cuts.map(tokenize_and_encode_text) - - valid_dl = aishell.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md deleted file mode 120000 index a784292cd..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/README.md \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py deleted file mode 120000 index 8554e44cc..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py deleted file mode 100755 index 61b929091..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py +++ /dev/null @@ -1,739 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import ContextGraph -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def 
get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding_method is modified_beam_search. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding_method is modified_beam_search. 
- """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - feature_lens += 30 - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, 30), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - key = f"beam_size_{params.beam_size}" - if params.has_contexts: - key += f"-context-score-{params.context_score}" - else: - key 
+= "-no-context-words" - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += "-no-contexts-words" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - 
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - if params.decoding_method == "modified_beam_search": - if os.path.exists(params.context_file): - contexts_text = [] - for line in open(params.context_file).readlines(): - contexts_text.append(line.strip()) - contexts = graph_compiler.texts_to_ids(contexts_text) - context_graph = ContextGraph(params.context_score) - context_graph.build([(c, 0.0) for c in contexts]) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - - test_dl = aishell.test_dataloaders(test_cuts) - dev_dl = aishell.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - import time - - for test_set, test_dl in zip(test_sets, test_dls): - start = time.time() - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - ) - logging.info(f"Elasped time for {test_set}: {time.time() - start}") - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 120000 index ca8fed319..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py deleted file mode 120000 index 33944d0d2..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py deleted file mode 100755 index 6653d9d9c..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ /dev/null @@ -1,1253 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
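For context biasing with `modified_beam_search`, the deleted decode.py above reads one phrase per line from `--context-file`, maps each phrase to token ids, and builds a `ContextGraph`. Below is a minimal sketch of that setup; it mirrors the deleted code, assumes an icefall installation plus a `CharCtcTrainingGraphCompiler`-like object exposing `texts_to_ids`, and keeps the neutral per-phrase bonus of 0.0 used there.

```python
import os
from typing import List, Optional

from icefall import ContextGraph  # same import as the deleted decode.py


def build_context_graph(
    context_file: str, graph_compiler, context_score: float
) -> Optional[ContextGraph]:
    """Build a biasing graph from a file containing one phrase per line.

    Mirrors the modified_beam_search setup in the deleted decode.py: each
    phrase is converted to token ids and added with a 0.0 phrase-level bonus.
    """
    if not os.path.exists(context_file):
        return None  # decoding proceeds without biasing
    contexts_text: List[str] = []
    with open(context_file) as f:
        for line in f:
            line = line.strip()
            if line:
                contexts_text.append(line)
    contexts = graph_compiler.texts_to_ids(contexts_text)
    context_graph = ContextGraph(context_score)
    context_graph.build([(c, 0.0) for c in contexts])
    return context_graph
```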
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
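The `--average-period` help above quotes the update rule for the running averaged model (`model_avg`). A small worked example of that rule, with hypothetical tensor values:

```python
import torch


def update_running_average(
    param_cur: torch.Tensor,
    param_avg: torch.Tensor,
    batch_idx_train: int,
    average_period: int,
) -> torch.Tensor:
    """One update of the start-of-training running average, using the rule
    quoted in the --average-period help string above."""
    w_cur = average_period / batch_idx_train
    w_avg = (batch_idx_train - average_period) / batch_idx_train
    return param_cur * w_cur + param_avg * w_avg


if __name__ == "__main__":
    # After 400 batches with average_period=200, the current weights and the
    # previous average each contribute one half.
    cur = torch.tensor([2.0])
    avg = torch.tensor([1.0])
    print(update_running_average(cur, avg, batch_idx_train=400, average_period=200))
    # tensor([1.5000])
```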
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
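`load_checkpoint_if_available` above resolves the resume checkpoint in a fixed priority order: `--start-batch` wins, otherwise `--start-epoch` N resumes from `epoch-(N-1).pt`, otherwise training starts fresh. A minimal sketch of just that path selection (the helper name is hypothetical):

```python
from pathlib import Path
from typing import Optional


def resume_checkpoint_path(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    """Pick the checkpoint to resume from, in the same priority order as
    load_checkpoint_if_available above."""
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run: nothing to load


if __name__ == "__main__":
    exp = Path("pruned_transducer_stateless7_streaming/exp")
    print(resume_checkpoint_path(exp, start_batch=0, start_epoch=5))
    # pruned_transducer_stateless7_streaming/exp/epoch-4.pt
```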
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. 
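The two warm-up scales computed in `compute_loss` above interpolate linearly over the first `warm_step` batches: the simple-loss weight decays from 1.0 to `--simple-loss-scale`, while the pruned-loss weight grows from 0.1 to 1.0. A small worked example with the defaults used in this recipe:

```python
from typing import Tuple


def loss_scales(
    batch_idx_train: int, warm_step: int, simple_loss_scale: float
) -> Tuple[float, float]:
    """Return (simple_loss_scale, pruned_loss_scale) for the current batch,
    following the warm-up formulas in compute_loss above."""
    if batch_idx_train >= warm_step:
        return simple_loss_scale, 1.0
    frac = batch_idx_train / warm_step
    simple = 1.0 - frac * (1.0 - simple_loss_scale)
    pruned = 0.1 + 0.9 * frac
    return simple, pruned


if __name__ == "__main__":
    # Defaults in this recipe: warm_step=2000, --simple-loss-scale 0.5
    for step in (0, 1000, 2000):
        print(step, loss_scales(step, warm_step=2000, simple_loss_scale=0.5))
    # 0    -> (1.0, 0.1)
    # 1000 -> (0.75, 0.55)
    # 2000 -> (0.5, 1.0)
```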
- scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
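Alongside periodic saving, the training loop above calls `remove_checkpoints(out_dir, topk=params.keep_last_k, ...)` so that only a bounded number of `checkpoint-*.pt` files stay on disk. A rough, file-system-only sketch of that idea; the exact behavior of icefall's helper may differ:

```python
import re
from pathlib import Path


def keep_last_k_checkpoints(out_dir: Path, topk: int) -> None:
    """Keep only the `topk` newest checkpoint-<batch>.pt files in out_dir,
    leaving epoch-*.pt files untouched."""
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")
    ckpts = []
    for p in out_dir.glob("checkpoint-*.pt"):
        m = pattern.search(p.name)
        if m:
            ckpts.append((int(m.group(1)), p))
    ckpts.sort(reverse=True)  # newest (largest global batch index) first
    for _, path in ckpts[topk:]:
        path.unlink()
```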
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
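The fp16 branch above nudges the AMP grad scale when it has collapsed: double it whenever it is below 1.0 (or below 8.0, every 400 batches), warn once it drops under 0.01, and abort once it falls under 1e-5. A condensed sketch of that decision, with the checks reordered for brevity:

```python
import logging


def adjust_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    """Condensed version of the fp16 bookkeeping above.

    Returns the scale that would be passed to scaler.update(); the original
    code performs the doubling first and the small-scale checks afterwards.
    """
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad scale is too small: {cur_grad_scale}")
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0
    return cur_grad_scale


if __name__ == "__main__":
    print(adjust_grad_scale(0.5, batch_idx=123))   # 1.0
    print(adjust_grad_scale(4.0, batch_idx=400))   # 8.0
    print(adjust_grad_scale(16.0, batch_idx=400))  # 16.0
```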
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - train_cuts = aishell.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - # T = ((c.num_frames - 7) // 2 + 1) // 2 - # tokens = sp.encode(c.supervisions[0].text, out_type=str) - - # if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
" - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - # return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = aishell.valid_cuts() - valid_dl = aishell.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py deleted file mode 120000 index b9aa0ae08..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py deleted file mode 120000 index 72e43c297..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py deleted file mode 120000 index 3b36924ef..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py deleted file mode 120000 index eca5e2956..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py deleted file mode 120000 index 57a0cd0a0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 120000 index 2acafdc61..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py deleted file mode 120000 index 5d9c6ba00..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py deleted file mode 120000 index 457131699..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py deleted file mode 120000 index 2b8fa3cbb..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py deleted file mode 120000 index ecfb6dd8a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py deleted file mode 120000 index e17d4f734..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py deleted file mode 120000 index 8eea90e04..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py deleted file mode 120000 index 28bf7bb82..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py deleted file mode 120000 index c8548d459..000000000 --- 
a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py deleted file mode 120000 index ae4d9bb04..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py deleted file mode 120000 index 81ac4a89a..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 120000 index 9510b8fde..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py deleted file mode 120000 index 2428b74b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py deleted file mode 120000 index b8b8ba432..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py deleted file mode 120000 index 92c3904af..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py deleted file mode 120000 index 1199a61d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py deleted file mode 100755 index a4b5cd588..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ /dev/null @@ -1,633 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# 
-# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -import os -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states - -from icefall import ContextGraph -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
- tail_length = 23 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - token_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 50 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... 
- # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - token_table[result] - for result in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - token_table[result] - for result in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
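`save_results` below (like its counterpart in the deleted decode.py earlier in this diff) ranks the decoding settings by WER and writes a two-column summary file. A condensed sketch of that summary step, with hypothetical WER values in the demo:

```python
from pathlib import Path
from typing import Dict


def write_wer_summary(
    res_dir: Path,
    test_set_name: str,
    suffix: str,
    test_set_wers: Dict[str, float],
) -> None:
    """Rank decoding settings by WER (best first) and write a TSV summary,
    mirroring the tail of save_results."""
    ranked = sorted(test_set_wers.items(), key=lambda kv: kv[1])
    out = res_dir / f"wer-summary-{test_set_name}-{suffix}.txt"
    with open(out, "w") as f:
        print("settings\tWER", file=f)
        for key, val in ranked:
            print(f"{key}\t{val}", file=f)
    best_key, best_wer = ranked[0]
    print(f"best for {test_set_name}: {best_key} ({best_wer})")


if __name__ == "__main__":
    write_wer_summary(
        Path("."),
        "test",
        "epoch-28-avg-15",
        {"greedy_search": 5.42, "num_active_paths_4": 5.11},
    )
```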
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found 
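# When --use-averaged-model is false, main() above averages the last --avg
# epoch checkpoints directly.  A small sketch of how that filename list is
# assembled under the recipe's epoch-<n>.pt naming (the "exp" directory name
# is illustrative):
from pathlib import Path

def epoch_checkpoints_to_average(exp_dir: Path, epoch: int, avg: int):
    start = epoch - avg + 1
    # the `start >= 0` guard mirrors the one in the script above
    return [exp_dir / f"epoch-{i}.pt" for i in range(start, epoch + 1) if start >= 0]

names = [p.name for p in epoch_checkpoints_to_average(Path("exp"), epoch=30, avg=3)]
assert names == ["epoch-28.pt", "epoch-29.pt", "epoch-30.pt"]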
for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - if params.decoding_method == "modified_beam_search": - if os.path.exists(params.context_file): - contexts_text = [] - for line in open(params.context_file).readlines(): - contexts_text.append(line.strip()) - contexts = graph_compiler.texts_to_ids(contexts_text) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - aishell = AishellAsrDataModule(args) - - test_cuts = aishell.test_cuts() - valid_cuts = aishell.valid_cuts() - - test_sets = ["test", "valid"] - cuts = [test_cuts, valid_cuts] - - for test_set, test_cut in zip(test_sets, cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py deleted file mode 120000 index 1259849e0..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py deleted file mode 100755 index f3b0f1e11..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ /dev/null @@ -1,1250 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention 
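# The per-stack options above ("2,4,3,2,4", "384,384,384,384,384", ...) are
# plain comma-separated strings; get_encoder_model() later converts them with
# a small to_int_tuple() helper.  A sketch of that parsing plus a sanity check
# that every option names the same number of encoder stacks:
def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

opts = {
    "num_encoder_layers": "2,4,3,2,4",
    "encoder_dims": "384,384,384,384,384",
    "zipformer_downsampling_factors": "1,2,4,8,2",
}
parsed = {name: to_int_tuple(value) for name, value in opts.items()}
assert {len(v) for v in parsed.values()} == {5}, "options must agree on #stacks"
assert parsed["num_encoder_layers"] == (2, 4, 3, 2, 4)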
dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. 
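# --short-chunk-size above describes dynamic chunk training: each batch is
# processed either with the full sequence length as the chunk size or with a
# chunk size drawn uniformly from (1, short_chunk_size).  The sampling below
# is one plausible reading of that description (the real logic lives inside
# zipformer.py); the 50/50 split is an assumption.
import random

def sample_chunk_size(max_seq_len: int, short_chunk_size: int = 50) -> int:
    if random.random() < 0.5:
        return max_seq_len                    # full-context (non-streaming) pass
    return random.randint(1, short_chunk_size)  # limited-context streaming pass

random.seed(0)
sizes = [sample_chunk_size(max_seq_len=600) for _ in range(8)]
assert all(1 <= s <= 600 for s in sizes)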
- """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> 
Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
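# load_checkpoint_if_available() above decides what to resume from:
# checkpoint-{start_batch}.pt wins if --start-batch is positive, otherwise
# epoch-{start_epoch-1}.pt if --start-epoch > 1, otherwise nothing is loaded.
# A standalone sketch of just that resolution step:
from pathlib import Path
from typing import Optional

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run

assert resume_filename(Path("exp"), start_batch=0, start_epoch=1) is None
assert resume_filename(Path("exp"), start_batch=0, start_epoch=5).name == "epoch-4.pt"
assert resume_filename(Path("exp"), start_batch=8000, start_epoch=5).name == "checkpoint-8000.pt"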
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. 
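# compute_loss() above blends the two transducer losses with warmup-dependent
# weights: the simple-loss weight decays linearly from 1.0 to
# params.simple_loss_scale over warm_step batches, while the pruned-loss
# weight grows from 0.1 to 1.0.  A standalone sketch of those schedules:
def loss_scales(batch_idx_train: int, warm_step: int, simple_loss_scale: float):
    if batch_idx_train >= warm_step:
        return simple_loss_scale, 1.0
    frac = batch_idx_train / warm_step
    simple_scale = 1.0 - frac * (1.0 - simple_loss_scale)
    pruned_scale = 0.1 + 0.9 * frac
    return simple_scale, pruned_scale

assert loss_scales(0, warm_step=2000, simple_loss_scale=0.5) == (1.0, 0.1)
assert loss_scales(2000, warm_step=2000, simple_loss_scale=0.5) == (0.5, 1.0)
mid = loss_scales(1000, warm_step=2000, simple_loss_scale=0.5)
assert abs(mid[0] - 0.75) < 1e-9 and abs(mid[1] - 0.55) < 1e-9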
- train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
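# train_one_epoch() above keeps a smoothed running total of the loss stats:
#   tot_loss = tot_loss * (1 - 1 / reset_interval) + loss_info
# so older batches decay geometrically (reset_interval is 200 in get_params()).
# A scalar sketch of that smoothing; the recipe's MetricsTracker applies it to
# every tracked statistic:
def smooth(tot: float, new: float, reset_interval: int = 200) -> float:
    return tot * (1.0 - 1.0 / reset_interval) + new

tot = 0.0
for batch_loss in (4.0, 4.0, 4.0):
    tot = smooth(tot, batch_loss)
# With a constant per-batch value v, `tot` converges to v * reset_interval.
assert 11.9 < tot < 12.0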
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
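# The fp16 branch above keeps the GradScaler healthy: every 100 batches the
# scale is doubled if it dropped below 1.0 (or below 8.0 on every 400th
# batch), a warning fires below 0.01, and training aborts below 1e-5.  A
# condensed sketch of that decision, detached from the real
# torch.cuda.amp.GradScaler:
def grad_scale_action(batch_idx: int, cur_grad_scale: float) -> str:
    if batch_idx % 100 != 0:
        return "keep"
    if cur_grad_scale < 1.0e-05:
        return "raise"    # grad scale is hopelessly small; stop training
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return "double"   # i.e. scaler.update(cur_grad_scale * 2.0)
    return "keep"

assert grad_scale_action(100, 0.5) == "double"
assert grad_scale_action(100, 4.0) == "keep"     # >= 1.0 and not a 400th batch
assert grad_scale_action(400, 4.0) == "double"   # < 8.0 on a 400th batch
assert grad_scale_action(100, 1e-6) == "raise"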
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - train_cuts = aishell.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - # T = ((c.num_frames - 7) // 2 + 1) // 2 - # tokens = sp.encode(c.supervisions[0].text, out_type=str) - - # if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
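# remove_short_and_long_utt() above keeps 1-20 s utterances, and its
# commented-out part would also enforce the pruned RNN-T requirement T >= S
# with T = ((num_frames - 7) // 2 + 1) // 2 frames after subsampling.  A
# standalone sketch of the combined check with plain numbers instead of
# lhotse Cut objects:
def keep_utterance(duration: float, num_frames: int, num_tokens: int) -> bool:
    if duration < 1.0 or duration > 20.0:
        return False                        # too short or too long for training
    T = ((num_frames - 7) // 2 + 1) // 2    # frames left after subsampling
    return T >= num_tokens                  # pruned RNN-T needs T >= S

assert keep_utterance(duration=5.0, num_frames=500, num_tokens=20)      # T = 123
assert not keep_utterance(duration=0.5, num_frames=50, num_tokens=3)    # too short
assert not keep_utterance(duration=1.2, num_frames=120, num_tokens=60)  # T = 28 < 60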
" - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - return False - - return True - - # train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = aishell.valid_cuts() - valid_dl = aishell.valid_dataloaders(valid_cuts) - - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # graph_compiler=graph_compiler, - # params=params, - # ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 120000 index ec183baa7..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py deleted file mode 120000 index d301e1f9b..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/aishell/ASR/shared b/egs/aishell/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/aishell/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/README.md b/egs/aishell/ASR/tdnn_lstm_ctc/README.md deleted file mode 100644 index c003fd419..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/README.md +++ /dev/null @@ -1,4 +0,0 @@ - -Please visit - -for how to run this recipe. diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py deleted file mode 100644 index aacbd153d..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
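# main() in train.py above starts one process per GPU with
# torch.multiprocessing.spawn when --world-size > 1 and simply calls run()
# otherwise.  A minimal runnable sketch of that launch pattern with a stub
# run(); the real run() builds the model, DDP wrapper, dataloaders and loop.
import torch.multiprocessing as mp

def run(rank: int, world_size: int, args) -> None:
    print(f"rank {rank}/{world_size} started with args={args}")

def launch(world_size: int, args=None) -> None:
    assert world_size >= 1
    if world_size > 1:
        # spawn() prepends the process rank to the argument list
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)

if __name__ == "__main__":
    launch(world_size=1, args={"max_duration": 300})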
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class AishellAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. 
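# The SpecAugment setup above chooses num_frame_masks by inspecting the
# default value of SpecAugment.__init__, because that default differs across
# lhotse versions (2 is used when the library default is 1, otherwise 10).
# A self-contained sketch of the inspection trick using a stand-in function,
# so lhotse itself is not imported here:
import inspect

def fake_spec_augment_init(time_warp_factor=80, num_frame_masks=1):
    """Stand-in for SpecAugment.__init__ with an assumed default of 1."""

default = inspect.signature(fake_spec_augment_init).parameters["num_frame_masks"].default
num_frame_masks = 2 if default == 1 else 10
assert num_frame_masks == 2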
- # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return 
load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py deleted file mode 100755 index 05e52f560..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ /dev/null @@ -1,399 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from model import TdnnLstm - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import get_lattice, nbest_decoding, one_best_decoding -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Supported values are: - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - """, - ) - parser.add_argument( - "--num-paths", - type=int, - default=30, - help="""Number of paths for n-best based decoding method. - Used only when "method" is nbest. - """, - ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. 
- """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), - # parameters for tdnn_lstm_ctc - "subsampling_factor": 3, - "feature_dim": 80, - # parameters for decoding - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - lexicon: Lexicon, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if the decoding method is 1best, the key is the string - `no_rescore`. If the decoding method is nbest, the key is the - string `no_rescore-xxx`, xxx is the num_paths. - - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains word symbol table. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is [N, T, C] - - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - - nnet_output = model(feature) - # nnet_output is [N, T, C] - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - assert params.method in ["1best", "nbest"] - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - lexicon: Lexicon, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - lexicon: - It contains word symbol table. 
- Returns: - Return a dict, whose key may be "no-rescore" if decoding method is 1best, - or it may be "no-rescoer-100" if decoding method is nbest. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - # We compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - - model.to(device) - model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - # CAUTION: `test_sets` is for displaying only. - # If you want to skip test-clean, you have to skip - # it inside the for loop. That is, use - # - # if test_set == 'test-clean': continue - # - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - lexicon=lexicon, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py deleted file mode 100644 index 1731e1ebe..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class TdnnLstm(nn.Module): - def __init__( - self, num_features: int, num_classes: int, subsampling_factor: int = 3 - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - It reduces the number of output frames by this factor. - """ - super().__init__() - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=self.subsampling_factor, # stride: subsampling_factor! - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - ) - self.lstms = nn.ModuleList( - [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] - ) - self.lstm_bnorms = nn.ModuleList( - [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] - ) - self.dropout = nn.Dropout(0.2) - self.linear = nn.Linear(in_features=500, out_features=self.num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Its shape is [N, C, T] - - Returns: - The output tensor has shape [N, T, C] - """ - x = self.tdnn(x) - x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it - for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): - x_new, _ = lstm(x) - x_new = bnorm(x_new.permute(1, 2, 0)).permute( - 2, 0, 1 - ) # (T, N, C) -> (N, C, T) -> (T, N, C) - x_new = self.dropout(x_new) - x = x_new + x # skip connections - x = x.transpose( - 1, 0 - ) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim - x = self.linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py deleted file mode 100644 index 9754b4939..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ /dev/null @@ -1,221 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import TdnnLstm -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Use the best path as decoding output. Only the transformer encoder - output is used for decoding. We call it HLG decoding. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 3, - "num_classes": 220, - "sample_rate": 16000, - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - features = features.permute(0, 2, 1) # now features is [N, C, T] - - with torch.no_grad(): - nnet_output = model(features) - # nnet_output is [N, T, C] - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - assert params.method == "1best" - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py deleted file mode 100755 index e574cf89b..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ /dev/null @@ -1,627 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage - export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./tdnn_lstm_ctc/train.py \ - --world-size 4 \ - --num-epochs 20 \ - --max-duration 300 -""" - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional - -import k2 -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import AishellAsrDataModule -from lhotse.utils import fix_random_seed -from model import TdnnLstm -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp_lr1e-4"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-4, - "feature_dim": 80, - "weight_decay": 5e-4, - "subsampling_factor": 3, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -): - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of TdnnLstm in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is [N, T, C] - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - - return loss - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
- """ - model.eval() - - tot_loss = 0.0 - tot_frames = 0.0 - for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames - - if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] - - params.valid_loss = tot_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" - ) - if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) - - params.train_loss = params.tot_loss / params.tot_frames - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=8, gamma=0.1) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if epoch > params.start_epoch: - logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") - - if tb_writer is not None: - tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - scheduler.step() - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless/README.md b/egs/aishell/ASR/transducer_stateless/README.md deleted file mode 100644 index 622cb837c..000000000 --- a/egs/aishell/ASR/transducer_stateless/README.md +++ /dev/null @@ -1,21 +0,0 @@ -## Introduction - -The decoder, i.e., the prediction network, is from -https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 -(Rnn-Transducer with Stateless Prediction Network) - -You can use the following command to start the training: - -```bash -cd egs/aishell/ASR - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./transducer_stateless/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ - --max-duration 250 \ - --lr-factor 2.5 -``` diff --git 
a/egs/aishell/ASR/transducer_stateless/__init__.py b/egs/aishell/ASR/transducer_stateless/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell/ASR/transducer_stateless/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless/asr_datamodule.py deleted file mode 120000 index a73848de9..000000000 --- a/egs/aishell/ASR/transducer_stateless/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py deleted file mode 100644 index de0a8d0f5..000000000 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Dict, List, Optional - -import numpy as np -import torch -from model import Transducer - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: - """ - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y != blank_id: - hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - return hyp - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. 
- ys: List[int] - - # The log prob of ys - log_prob: float - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None): - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self): - return self._data - - # def add(self, ys: List[int], log_prob: float): - def add(self, hyp: Hypothesis): - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: float) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - that have `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for key, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out) - - # TODO(fangjun): Ccale the blank posterior - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id: - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py deleted file mode 100755 index d958a6338..000000000 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ /dev/null @@ -1,461 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import beam_search, greedy_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --decoding-method is beam_search", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=3, - help="Maximum number of symbols per frame", - ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - transducer_stateless/exp/pretrained.pt. Note: only model.state_dict() - is saved. pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. 
- """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains the token symbol table and the word symbol table. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - hyps.append([lexicon.token_table[i] for i in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - else: - return {f"beam_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ("greedy_search", "beam_search") - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - - # params.blank_id = graph_compiler.texts_to_ids("")[0][0] - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - model.device = device - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py deleted file mode 100644 index 130f080ec..000000000 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, - kernel_size=context_size, - padding=0, - groups=embedding_dim, - bias=False, - ) - else: - # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` - # when inference with torch.jit.script and context_size == 1 - self.conv = nn.Identity() - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U) with blank prepended. - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, embedding_dim). 
- """ - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - return embedding_out diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py deleted file mode 100755 index bfd0ecb0c..000000000 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ /dev/null @@ -1,247 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./transducer_stateless/export.py \ - --exp-dir ./transducer_stateless/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `transducer_stateless/decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./transducer_stateless/decode.py \ - --exp-dir ./transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 1 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. 
- """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless/joiner.py b/egs/aishell/ASR/transducer_stateless/joiner.py deleted file mode 100644 index 2ef3f1de6..000000000 --- a/egs/aishell/ASR/transducer_stateless/joiner.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn - - -class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): - super().__init__() - - self.output_linear = nn.Linear(input_dim, output_dim) - - def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, C). - decoder_out: - Output from the decoder. Its shape is (N, U, C). - Returns: - Return a tensor of shape (N, T, U, C). - """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) - - logit = encoder_out + decoder_out - logit = torch.tanh(logit) - - output = self.output_linear(logit) - - return output diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py deleted file mode 100644 index 591bbe44f..000000000 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
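# A minimal, torch-only sketch of the joiner broadcasting implemented in the joiner.py
# shown above: the encoder output (N, T, C) and decoder output (N, U, C) are expanded
# to (N, T, 1, C) and (N, 1, U, C), summed, squashed with tanh, and projected to
# vocabulary-sized logits.  The sizes below are illustrative only.
import torch
import torch.nn as nn

N, T, U, C, vocab_size = 2, 50, 10, 512, 4336
encoder_out = torch.randn(N, T, C)          # acoustic frames
decoder_out = torch.randn(N, U, C)          # label-context embeddings
output_linear = nn.Linear(C, vocab_size)

logit = torch.tanh(encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1))
logits = output_linear(logit)               # (N, T, U, vocab_size)
assert logits.shape == (N, T, U, vocab_size)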
- -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain - one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - Returns: - Return the transducer loss. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - decoder_out = self.decoder(sos_y_padded) - - logits = self.joiner(encoder_out, decoder_out) - - # rnnt_loss requires 0 padded targets - # Note: y does not start with SOS - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary) - - return loss diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py deleted file mode 100755 index 540e7b61b..000000000 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ /dev/null @@ -1,322 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
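# The model.py above hands k2.rnnt_loss a `boundary` tensor whose row i is
# (0, 0, y_lens[i], x_lens[i]), marking where the valid labels and frames of each
# utterance end.  A torch-only illustration of that construction (the lengths are
# made-up numbers, not real data):
import torch

x_lens = torch.tensor([98, 75, 120])   # valid encoder frames per utterance
y_lens = torch.tensor([13, 9, 17])     # labels per utterance

boundary = torch.zeros((x_lens.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
print(boundary)
# k2.rnnt_loss(logits, y_padded, blank_id, boundary) then ignores padded
# positions beyond these per-utterance bounds.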
-""" -Usage: - -(1) greedy search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav \ - -(1) beam search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ - -You can also use `./transducer_stateless/exp/epoch-xx.pt`. - -Note: ./transducer_stateless/exp/pretrained.pt is generated by -./transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import kaldifeat -import torch -import torch.nn as nn -import torchaudio -from beam_search import beam_search, greedy_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer -from torch.nn.utils.rnn import pad_sequence - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=3, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. 
- """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = graph_compiler.texts_to_ids("")[0][0] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless/subsampling.py b/egs/aishell/ASR/transducer_stateless/subsampling.py deleted file mode 120000 index 6fee09e58..000000000 --- a/egs/aishell/ASR/transducer_stateless/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless/test_decoder.py b/egs/aishell/ASR/transducer_stateless/test_decoder.py deleted file mode 100755 index fe0bdee70..000000000 --- a/egs/aishell/ASR/transducer_stateless/test_decoder.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/aishell/ASR - python ./transducer_stateless/test_decoder.py -""" - -import torch -from decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - embedding_dim = 128 - context_size = 4 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - context_size=context_size, - ) - N = 100 - U = 20 - x = torch.randint(low=0, high=vocab_size, size=(N, U)) - y = decoder(x) - assert y.shape == (N, U, embedding_dim) - - # for inference - x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x, need_pad=False) - assert y.shape == (N, 1, embedding_dim) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py deleted file mode 100755 index 62ffff473..000000000 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ /dev/null @@ -1,675 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
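# A simplified, standalone re-implementation of the context trick exercised by the
# decoder.py/test_decoder.py above: an embedding followed by a depthwise Conv1d over
# the last `context_size` labels.  During training the sequence is left-padded so each
# position only sees previous labels; at inference exactly `context_size` labels are
# passed and no padding is needed.  Sizes here are arbitrary.
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embedding_dim, context_size, blank_id = 100, 32, 2, 0
embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size,
                 groups=embedding_dim, bias=False)   # depthwise: one filter per channel

def context_embedding(y: torch.Tensor, need_pad: bool) -> torch.Tensor:
    out = embedding(y).permute(0, 2, 1)              # (N, U) -> (N, C, U)
    if need_pad:
        out = F.pad(out, pad=(context_size - 1, 0))  # pad on the left (past) only
    return conv(out).permute(0, 2, 1)                # back to (N, U', C)

y_train = torch.randint(0, vocab_size, (4, 20))
assert context_embedding(y_train, need_pad=True).shape == (4, 20, embedding_dim)
y_infer = torch.randint(0, vocab_size, (4, context_size))
assert context_embedding(y_infer, need_pad=False).shape == (4, 1, embedding_dim)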
- - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from model import Transducer -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - attention_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for Noam - "warm_step": 80000, # For the 100h subset, use 8k - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. 
- """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
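# Because the transducer loss above is a plain sum over utterances (reduction == sum,
# no normalization), logging divides the accumulated loss by the accumulated frame
# count.  A tiny sketch of that bookkeeping, including the decayed running total used
# for `tot_loss`; the per-batch numbers are invented:
reset_interval = 200
running = {"loss": 0.0, "frames": 0.0}
batches = [(1234.0, 4100), (1180.5, 3950), (1302.2, 4210)]  # (summed loss, frames)

for loss_sum, frames in batches:
    decay = 1.0 - 1.0 / reset_interval          # slowly forget old batches
    running["loss"] = running["loss"] * decay + loss_sum
    running["frames"] = running["frames"] * decay + frames
    print(f"loss per frame so far: {running['loss'] / running['frames']:.4f}")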
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - aishell = AishellAsrDataModule(args) - train_cuts = aishell.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. 
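# The 1--12 second filter above is chosen from the corpus duration statistics.  A rough
# sketch of how one could inspect duration percentiles of a lhotse CutSet before picking
# such thresholds (the manifest path is only an example):
import numpy as np
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/aishell_cuts_train.jsonl.gz")
durations = np.array([c.duration for c in cuts])
for q in (0.5, 1, 5, 50, 95, 99, 99.5):
    print(f"p{q}: {np.percentile(durations, q):.2f} s")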
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 12.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = aishell.train_dataloaders(train_cuts) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py deleted file mode 100644 index b3ff153c1..000000000 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling - -from icefall.utils import make_pad_mask - - -class Transformer(EncoderInterface): - def __init__( - self, - num_features: int, - output_dim: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - output_dim: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. 
- d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder. - num_encoder_layers: - Number of encoder layers. - dropout: - Dropout in encoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - """ - super().__init__() - - self.num_features = num_features - self.output_dim = output_dim - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) - ) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - Returns: - Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 - assert x.size(0) == lengths.max().item() - - mask = make_pad_mask(lengths) - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return logits, lengths - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. 
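# Transformer.forward above hard-codes the subsampling arithmetic:
# lengths = ((x_lens - 1) // 2 - 1) // 2, i.e. roughly a quarter of the input frames.
# A quick numeric check of that formula:
import torch

x_lens = torch.tensor([16, 17, 100, 101, 1000])
lengths = ((x_lens - 1) // 2 - 1) // 2
print(list(zip(x_lens.tolist(), lengths.tolist())))
# e.g. 100 input frames -> 24 output frames, slightly fewer than 100 // 4.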
- - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). 
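# The PositionalEncoding above relies on the identity
# 1 / 10000^(2i/d_model) = exp(2i * -(log(10000) / d_model)) to build div_term.
# A short numerical check of that rewrite and of the resulting sin/cos table
# (d_model and length are arbitrary):
import math
import torch

d_model, T = 8, 5
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
direct = 1.0 / (10000.0 ** (two_i / d_model))
via_exp = torch.exp(two_i * -(math.log(10000.0) / d_model))
assert torch.allclose(direct, via_exp)

position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
pe = torch.zeros(T, d_model)
pe[:, 0::2] = torch.sin(position * via_exp)
pe[:, 1::2] = torch.cos(position * via_exp)
assert pe.shape == (T, d_model)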
The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/README.md b/egs/aishell/ASR/transducer_stateless_modified-2/README.md deleted file mode 100644 index b3c539670..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/README.md +++ /dev/null @@ -1,59 +0,0 @@ -## Introduction - -The decoder, i.e., the prediction network, is from -https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 -(Rnn-Transducer with Stateless Prediction Network) - 
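# The Noam optimizer above scales Adam's learning rate as
# factor * model_size^-0.5 * min(step^-0.5, step * warm_step^-1.5):
# a linear warm-up followed by inverse-square-root decay that peaks at
# step == warm_step.  A small sketch of that schedule, using the train.py defaults
# seen earlier (attention_dim 512, lr-factor 5.0, warm_step 80000):
def noam_rate(step: int, model_size: int = 512, factor: float = 5.0,
              warm_step: int = 80000) -> float:
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

for step in (1000, 40000, 80000, 160000, 320000):
    print(step, f"{noam_rate(step):.2e}")
# Both terms of the min() coincide at step == warm_step, where the rate is largest.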
-Different from `../transducer_stateless_modified`, this folder -uses extra data, i.e., http://www.openslr.org/62/, during training. - -You can use the following command to start the training: - -```bash -cd egs/aishell/ASR -./prepare.sh --stop-stage 6 -./prepare_aidatatang_200zh.sh - -export CUDA_VISIBLE_DEVICES="0,1,2" - -./transducer_stateless_modified-2/train.py \ - --world-size 3 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 250 \ - --lr-factor 2.0 \ - --context-size 2 \ - --modified-transducer-prob 0.25 \ - --datatang-prob 0.2 -``` - -To decode, you can use - -```bash -for epoch in 89; do - for avg in 30 38; do - ./transducer_stateless_modified-2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 - done -done - -for epoch in 89; do - for avg in 38; do - ./transducer_stateless_modified-2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 100 \ - --context-size 2 \ - --decoding-method modified_beam_search \ - --beam-size 4 - done -done -``` diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/__init__.py b/egs/aishell/ASR/transducer_stateless_modified-2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py b/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py deleted file mode 100644 index 26d4ee111..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
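# The README above trains on AIShell plus aidatatang_200zh with --datatang-prob 0.2.
# One generic way to interleave two dataloaders with a given probability is sketched
# below; this is purely illustrative and not necessarily how the deleted train.py
# implements --datatang-prob:
import random

def mixed_batches(aishell_dl, datatang_dl, datatang_prob: float = 0.2, seed: int = 0):
    """Each step, draw from datatang_dl with probability datatang_prob, otherwise
    from aishell_dl; one pass over aishell_dl defines the epoch."""
    rng = random.Random(seed)
    aishell, datatang = iter(aishell_dl), iter(datatang_dl)
    while True:
        use_datatang = rng.random() < datatang_prob
        try:
            yield next(datatang) if use_datatang else next(aishell)
        except StopIteration:
            if not use_datatang:
                return                       # aishell exhausted -> epoch done
            datatang = iter(datatang_dl)     # recycle the auxiliary corpus
            yield next(datatang)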
- -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy - - -class AIDatatang200zh: - def __init__(self, manifest_dir: str): - """ - Args: - manifest_dir: - It is expected to contain the following files:: - - - aidatatang_cuts_dev.jsonl.gz - - aidatatang_cuts_train.jsonl.gz - - aidatatang_cuts_test.jsonl.gz - """ - self.manifest_dir = Path(manifest_dir) - - def train_cuts(self) -> CutSet: - f = self.manifest_dir / "aidatatang_cuts_train.jsonl.gz" - logging.info(f"About to get train cuts from {f}") - cuts_train = load_manifest_lazy(f) - return cuts_train - - def valid_cuts(self) -> CutSet: - f = self.manifest_dir / "aidatatang_cuts_valid.jsonl.gz" - logging.info(f"About to get valid cuts from {f}") - cuts_valid = load_manifest_lazy(f) - return cuts_valid - - def test_cuts(self) -> CutSet: - f = self.manifest_dir / "aidatatang_cuts_test.jsonl.gz" - logging.info(f"About to get test cuts from {f}") - cuts_test = load_manifest_lazy(f) - return cuts_test diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py b/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py deleted file mode 100644 index ddeca4d88..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy - - -class AIShell: - def __init__(self, manifest_dir: str): - """ - Args: - manifest_dir: - It is expected to contain the following files:: - - - aishell_cuts_dev.jsonl.gz - - aishell_cuts_train.jsonl.gz - - aishell_cuts_test.jsonl.gz - """ - self.manifest_dir = Path(manifest_dir) - - def train_cuts(self) -> CutSet: - f = self.manifest_dir / "aishell_cuts_train.jsonl.gz" - logging.info(f"About to get train cuts from {f}") - cuts_train = load_manifest_lazy(f) - return cuts_train - - def valid_cuts(self) -> CutSet: - f = self.manifest_dir / "aishell_cuts_dev.jsonl.gz" - logging.info(f"About to get valid cuts from {f}") - cuts_valid = load_manifest_lazy(f) - return cuts_valid - - def test_cuts(self) -> CutSet: - f = self.manifest_dir / "aishell_cuts_test.jsonl.gz" - logging.info(f"About to get test cuts from {f}") - cuts_test = load_manifest_lazy(f) - return cuts_test diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py deleted file mode 100644 index ed453afd2..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2022 Xiaomi Corp. 
(authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import inspect -import logging -from pathlib import Path -from typing import Any, Dict, Optional - -from lhotse import CutSet, Fbank, FbankConfig -from lhotse.dataset import ( - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class AsrDataModule: - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - on_the_fly_feats: bool, - cuts_musan: Optional[CutSet] = None, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - Cuts for training. - cuts_musan: - If not None, it is the cuts for mixing. - on_the_fly_feats: - True to use OnTheFlyFeatures; - False to use PrecomputedFeatures. - """ - transforms = [] - if cuts_musan is not None: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - input_transforms = [] - - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if on_the_fly_feats - else PrecomputedFeatures() - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - logging.info("About to create train dataloader") - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py b/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py deleted file mode 120000 index e188617a8..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py b/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py deleted file mode 120000 index 88975988f..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/conformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py deleted file mode 100755 index 57f7a8239..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ /dev/null @@ -1,517 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the 
Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./transducer_stateless_modified-2/decode.py \ - --epoch 89 \ - --avg 38 \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./transducer_stateless_modified-2/decode.py \ - --epoch 89 \ - --avg 38 \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./transducer_stateless_modified-2/decode.py \ - --epoch 89 \ - --avg 38 \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 -(4) fast beam search -./transducer_stateless_modified-2/decode.py \ - --epoch 89 \ - --avg 38 \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --max-duration 100 \ - --decoding-method fast_beam_search \ - --beam-size 4 \ - --max-contexts 4 \ - --max-states 8 -""" - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from aishell import AIShell -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless_modified-2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - token_table: - It maps token ID to a string. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 10 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = 
torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - asr_datamodule = AsrDataModule(args) - aishell = AIShell(manifest_dir=args.manifest_dir) - test_cuts = aishell.test_cuts() - test_dl = asr_datamodule.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py b/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py deleted file mode 120000 index bdfcea5c2..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py b/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py deleted file mode 120000 index a2a5f22cf..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py deleted file mode 100755 index 4f2c71d18..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./transducer_stateless_modified-2/export.py \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --epoch 89 \ - --avg 38 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `transducer_stateless_modified-2/decode.py`, -you can do:: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./transducer_stateless_modified-2/decode.py \ - --exp-dir ./transducer_stateless_modified-2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default=Path("transducer_stateless_modified-2/exp"), - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py b/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py deleted file mode 120000 index e9e435ecd..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/model.py b/egs/aishell/ASR/transducer_stateless_modified-2/model.py deleted file mode 100644 index 086957d0b..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/model.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from typing import Optional - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - decoder_datatang: Optional[nn.Module] = None, - joiner_datatang: Optional[nn.Module] = None, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain - one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - decoder_datatang: - The decoder for the aidatatang_200zh dataset. - joiner_datatang: - The joiner for the aidatatang_200zh dataset. 
- """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - if decoder_datatang is not None: - assert hasattr(decoder_datatang, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.decoder_datatang = decoder_datatang - self.joiner_datatang = joiner_datatang - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - aishell: bool = True, - modified_transducer_prob: float = 0.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - modified_transducer_prob: - The probability to use modified transducer loss. - Returns: - Return the transducer loss. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - sos_y_padded = sos_y_padded.to(torch.int64) - - if aishell: - decoder = self.decoder - joiner = self.joiner - else: - decoder = self.decoder_datatang - joiner = self.joiner_datatang - - decoder_out = decoder(sos_y_padded) - - # +1 here since a blank is prepended to each utterance. - logits = joiner( - encoder_out=encoder_out, - decoder_out=decoder_out, - encoder_out_len=x_lens, - decoder_out_len=y_lens + 1, - ) - - # rnnt_loss requires 0 padded targets - # Note: y does not start with SOS - y_padded = y.pad(mode="constant", padding_value=0) - - # We don't put this `import` at the beginning of the file - # as it is required only in the training, not during the - # reference stage - import optimized_transducer - - assert 0 <= modified_transducer_prob <= 1 - - if modified_transducer_prob == 0: - one_sym_per_frame = False - elif random.random() < modified_transducer_prob: - # random.random() returns a float in the range [0, 1) - one_sym_per_frame = True - else: - one_sym_per_frame = False - - loss = optimized_transducer.transducer_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", - one_sym_per_frame=one_sym_per_frame, - from_log_softmax=False, - ) - - return loss diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py deleted file mode 100755 index 4a4e9237c..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - -(1) greedy search -./transducer_stateless_modified-2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./transducer_stateless_modified-2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./transducer_stateless_modified-2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./transducer_stateless_modified-2/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lens = [f.size(0) for f in features] - feature_lens = torch.tensor(feature_lens, device=device) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) - - num_waves = encoder_out.size(0) - hyp_list = [] - logging.info(f"Using {params.method}") - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_list = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_list = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_list = 
modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - hyp_list.append(hyp) - - hyps = [] - for hyp in hyp_list: - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py b/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py deleted file mode 120000 index 6fee09e58..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py b/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py deleted file mode 120000 index fbe1679ea..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/test_decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py deleted file mode 100755 index 8fb7d1e49..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ /dev/null @@ -1,866 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
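pretrained.py above converts raw wav files into a padded batch of 80-dim fbank features before running the encoder; a consolidated sketch of that front end, assuming kaldifeat and torchaudio are installed and the input is 16 kHz audio:

import math
from typing import List, Tuple

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def compute_fbank_batch(
    filenames: List[str],
    sample_rate: int = 16000,
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same feature settings as pretrained.py: 80 mel bins, no dither.
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400
    fbank = kaldifeat.Fbank(opts)

    waves = []
    for f in filenames:
        wave, sr = torchaudio.load(f)
        assert sr == sample_rate, f"expected {sample_rate} Hz, got {sr} Hz: {f}"
        waves.append(wave[0].to(device))  # use only the first channel

    features = fbank(waves)
    feature_lens = torch.tensor([f.size(0) for f in features], device=device)
    # Pad with log(1e-10), i.e. a very small log-energy, as in pretrained.py.
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    return features, feature_lens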
- -""" -Usage: -./prepare.sh -./prepare_aidatatang_200zh.sh - -export CUDA_VISIBLE_DEVICES="0,1,2" - -./transducer_stateless_modified-2/train.py \ - --world-size 3 \ - --num-epochs 90 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified-2/exp-2 \ - --max-duration 250 \ - --lr-factor 2.0 \ - --context-size 2 \ - --modified-transducer-prob 0.25 \ - --datatang-prob 0.2 -""" - - -import argparse -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aidatatang_200zh import AIDatatang200zh -from aishell import AIShell -from asr_datamodule import AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from model import Transducer -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless_modified-2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--modified-transducer-prob", - type=float, - default=0.25, - help="""The probability to use modified transducer loss. - In modified transduer, it limits the maximum number of symbols - per frame to 1. 
See also the option --max-sym-per-frame in - transducer_stateless/decode.py - """, - ) - - parser.add_argument( - "--datatang-prob", - type=float, - default=0.2, - help="The probability to select a batch from the aidatatang_200zh dataset", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - attention_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 800, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for Noam - "warm_step": 80000, # For the 100h subset, use 8k - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - decoder_datatang = get_decoder_model(params) - joiner_datatang = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - decoder_datatang=decoder_datatang, - joiner_datatang=joiner_datatang, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. 
- model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def is_aishell(c: Cut) -> bool: - """Return True if this cut is from the AIShell dataset. - - Note: - During data preparation, we set the custom field in - the supervision segment of aidatatang_200zh to - dict(origin='aidatatang_200zh') - See ../local/process_aidatatang_200zh.py. - """ - return c.supervisions[0].custom is None - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - aishell = is_aishell(supervisions["cut"][0]) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - loss = model( - x=feature, - x_lens=feature_lens, - y=y, - aishell=aishell, - modified_transducer_prob=params.modified_transducer_prob, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - datatang_train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - rng: random.Random, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - datatang_train_dl: - Dataloader for the aidatatang_200zh training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - aishell_tot_loss = MetricsTracker() - datatang_tot_loss = MetricsTracker() - tot_loss = MetricsTracker() - - # index 0: for LibriSpeech - # index 1: for GigaSpeech - # This sets the probabilities for choosing which datasets - dl_weights = [1 - params.datatang_prob, params.datatang_prob] - - iter_aishell = iter(train_dl) - iter_datatang = iter(datatang_train_dl) - - batch_idx = 0 - - while True: - idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] - dl = iter_aishell if idx == 0 else iter_datatang - - try: - batch = next(dl) - except StopIteration: - break - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - aishell = is_aishell(batch["supervisions"]["cut"][0]) - - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - if aishell: - aishell_tot_loss = ( - aishell_tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - prefix = "aishell" # for logging only - else: - datatang_tot_loss = ( - datatang_tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - prefix = "datatang" - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
-
-        optimizer.zero_grad()
-        loss.backward()
-        clip_grad_norm_(model.parameters(), 5.0, 2.0)
-        optimizer.step()
-
-        if batch_idx % params.log_interval == 0:
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], "
-                f"aishell_tot_loss[{aishell_tot_loss}], "
-                f"datatang_tot_loss[{datatang_tot_loss}], "
-                f"batch size: {batch_size}"
-            )
-
-        if batch_idx % params.log_interval == 0:
-            if tb_writer is not None:
-                loss_info.write_summary(
-                    tb_writer,
-                    f"train/current_{prefix}_",
-                    params.batch_idx_train,
-                )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
-                aishell_tot_loss.write_summary(
-                    tb_writer, "train/aishell_tot_", params.batch_idx_train
-                )
-                datatang_tot_loss.write_summary(
-                    tb_writer, "train/datatang_tot_", params.batch_idx_train
-                )
-
-        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                graph_compiler=graph_compiler,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
-
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    params.train_loss = loss_value
-    if params.train_loss < params.best_train_loss:
-        params.best_train_epoch = params.cur_epoch
-        params.best_train_loss = params.train_loss
-
-
-def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 12 seconds
-        #
-        # Caution: There is a reason to select 12.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        return 1.0 <= c.duration <= 12.0
-
-    cuts = cuts.filter(remove_short_and_long_utt)
-
-    return cuts
-
-
-def run(rank, world_size, args):
-    """
-    Args:
-      rank:
-        It is a value between 0 and `world_size-1`, which is
-        passed automatically by `mp.spawn()` in :func:`main`.
-        The node with rank 0 is responsible for saving checkpoint.
-      world_size:
-        Number of GPUs for DDP training.
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - seed = 42 - fix_random_seed(seed) - rng = random.Random(seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - model.device = device - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - aishell = AIShell(manifest_dir=args.manifest_dir) - - train_cuts = aishell.train_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) - - datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) - train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) - train_datatang_cuts = train_datatang_cuts.repeat(times=None) - - if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") - else: - cuts_musan = None - - asr_datamodule = AsrDataModule(args) - - train_dl = asr_datamodule.train_dataloaders( - train_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - ) - - datatang_train_dl = asr_datamodule.train_dataloaders( - train_datatang_cuts, - on_the_fly_feats=False, - cuts_musan=cuts_musan, - ) - - valid_cuts = aishell.valid_cuts() - valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) - - for dl in [ - train_dl, - # datatang_train_dl - ]: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - datatang_train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - datatang_train_dl=datatang_train_dl, - valid_dl=valid_dl, - rng=rng, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 
1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - assert 0 <= args.datatang_prob < 1, args.datatang_prob - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py b/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py deleted file mode 120000 index 4320d1105..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless_modified/transformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/README.md b/egs/aishell/ASR/transducer_stateless_modified/README.md deleted file mode 100644 index 9709eb9a0..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/README.md +++ /dev/null @@ -1,21 +0,0 @@ -## Introduction - -The decoder, i.e., the prediction network, is from -https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 -(Rnn-Transducer with Stateless Prediction Network) - -You can use the following command to start the training: - -```bash -cd egs/aishell/ASR - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./transducer_stateless_modified/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified/exp \ - --max-duration 250 \ - --lr-factor 2.5 -``` diff --git a/egs/aishell/ASR/transducer_stateless_modified/__init__.py b/egs/aishell/ASR/transducer_stateless_modified/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py deleted file mode 120000 index a73848de9..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/beam_search.py b/egs/aishell/ASR/transducer_stateless_modified/beam_search.py deleted file mode 120000 index e188617a8..000000000 --- 
a/egs/aishell/ASR/transducer_stateless_modified/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/conformer.py b/egs/aishell/ASR/transducer_stateless_modified/conformer.py deleted file mode 120000 index 8be0dc864..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py deleted file mode 100755 index 56f3724eb..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ /dev/null @@ -1,518 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./transducer_stateless_modified/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless_modified/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./transducer_stateless_modified/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless_modified/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./transducer_stateless_modified/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless_modified/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./transducer_stateless_modified/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless_modified/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless_modified/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - token_table: - It maps token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 10 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = 
torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dls = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/decoder.py b/egs/aishell/ASR/transducer_stateless_modified/decoder.py deleted file mode 120000 index 82337f7ef..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py b/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py deleted file mode 100755 index 487748947..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ /dev/null @@ -1,245 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
-""" -Usage: -./transducer_stateless_modified/export.py \ - --exp-dir ./transducer_stateless_modified/exp \ - --epoch 64 \ - --avg 33 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `transducer_stateless_modified/decode.py`, -you can do:: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell/ASR - ./transducer_stateless_modified/decode.py \ - --exp-dir ./transducer_stateless_modified/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default=Path("transducer_stateless_modified/exp"), - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram",
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
-def main():
-    args = get_parser().parse_args()
-
-    params = get_params()
-    params.update(vars(args))
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    token_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = token_table["<blk>"]
-    params.vocab_size = num_tokens(token_table) + 1
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_transducer_model(params)
-
-    model.to(device)
-
-    if params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(
-            average_checkpoints(filenames, device=device), strict=False
-        )
-
-    model.to("cpu")
-    model.eval()
-
-    if params.jit:
-        # We won't use the forward() method of the model in C++, so just ignore
-        # it here.
-        # Otherwise, one of its arguments is a ragged tensor and is not
-        # torch scriptable.
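The comment above explains why the model's `forward()` is excluded before `torch.jit.script`: one of its arguments is a ragged tensor, which TorchScript cannot compile. A small standalone sketch of the same idea, using `torch.jit.ignore` and `torch.jit.export` on a hypothetical toy module (not the transducer model):

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stands in for a forward() that TorchScript cannot compile;
        # ignoring it keeps torch.jit.script from trying.
        return self.linear(x)

    @torch.jit.export
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Only this method is compiled and meant to be called from the
        # scripted module.
        return self.linear(x)


scripted = torch.jit.script(Toy())
print(scripted.encode(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```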
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/joiner.py b/egs/aishell/ASR/transducer_stateless_modified/joiner.py deleted file mode 120000 index 1aec6bfaf..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/model.py b/egs/aishell/ASR/transducer_stateless_modified/model.py deleted file mode 120000 index 16ddd93f0..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py deleted file mode 100755 index 66a91709e..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Usage: - -(1) greedy search -./transducer_stateless_modified/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./transducer_stateless_modified/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./transducer_stateless_modified/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./transducer_stateless_modified/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --lang-dir /path/to/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="Maximum number of symbols per frame. 
" - "Use only when --method is greedy_search", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lens = [f.size(0) for f in features] - feature_lens = torch.tensor(feature_lens, device=device) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) - - num_waves = encoder_out.size(0) - hyp_list = [] - logging.info(f"Using {params.method}") - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_list = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_list = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_list = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported decoding method: 
{params.method}") - hyp_list.append(hyp) - - hyps = [] - for hyp in hyp_list: - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/subsampling.py b/egs/aishell/ASR/transducer_stateless_modified/subsampling.py deleted file mode 120000 index 6fee09e58..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py b/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py deleted file mode 100755 index fe0bdee70..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/aishell/ASR - python ./transducer_stateless/test_decoder.py -""" - -import torch -from decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - embedding_dim = 128 - context_size = 4 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - context_size=context_size, - ) - N = 100 - U = 20 - x = torch.randint(low=0, high=vocab_size, size=(N, U)) - y = decoder(x) - assert y.shape == (N, U, embedding_dim) - - # for inference - x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x, need_pad=False) - assert y.shape == (N, 1, embedding_dim) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py deleted file mode 100755 index 5f116f2bd..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ /dev/null @@ -1,746 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# Copyright 2021 (Pingfeng Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2" - -./transducer_stateless_modified/train.py \ - --world-size 3 \ - --num-epochs 65 \ - --start-epoch 0 \ - --exp-dir transducer_stateless_modified/exp \ - --max-duration 250 \ - --lr-factor 2.0 \ - --context-size 2 \ - --modified-transducer-prob 0.25 -""" - - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from model import Transducer -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless_modified/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--modified-transducer-prob", - type=float, - default=0.25, - help="""The probability to use modified transducer loss. - In modified transduer, it limits the maximum number of symbols - per frame to 1. See also the option --max-sym-per-frame in - transducer_stateless/decode.py - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. 
- - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - attention_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 800, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for Noam - "warm_step": 80000, # For the 100h subset, use 8k - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. 
- - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - loss = model( - x=feature, - x_lens=feature_lens, - y=y, - modified_transducer_prob=params.modified_transducer_prob, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - oov="", - ) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - aishell = AishellAsrDataModule(args) - train_cuts = aishell.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 12.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = aishell.train_dataloaders(train_cuts) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/transformer.py b/egs/aishell/ASR/transducer_stateless_modified/transformer.py deleted file mode 120000 index 214afed39..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/asr_datamodule.py b/egs/aishell/ASR/whisper/asr_datamodule.py deleted file mode 120000 index fa1b8cca3..000000000 --- a/egs/aishell/ASR/whisper/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py deleted file mode 100755 index 5350cb2b0..000000000 --- a/egs/aishell/ASR/whisper/decode.py +++ /dev/null @@ -1,507 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
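# main() in the deleted transducer train.py above launches one process per
# GPU with mp.spawn and falls back to a direct call when world_size == 1.
# A minimal sketch of that launch pattern, with a toy worker standing in for
# run() (no DDP setup is performed here):
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, tag: str) -> None:
    # A real worker would call setup_dist(rank, world_size, ...) when
    # world_size > 1 and clean it up before returning.
    print(f"[{tag}] rank {rank} of {world_size} started")


def launch(world_size: int, tag: str = "demo") -> None:
    assert world_size >= 1
    if world_size > 1:
        # mp.spawn passes the process index as the first argument of worker.
        mp.spawn(worker, args=(world_size, tag), nprocs=world_size, join=True)
    else:
        worker(rank=0, world_size=1, tag=tag)


if __name__ == "__main__":
    launch(world_size=1)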
-""" -Usage: -# Command for decoding using fine-tuned models: -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --manifest-dir data/fbank_whisper \ - --beam-size 10 --max-duration 50 - -# Command for decoding using pretrained models (before fine-tuning): - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch -1 --avg 1 \ - --manifest-dir data/fbank_whisper \ - --remove-whisper-encoder-input-length-restriction False \ - --beam-size 10 --max-duration 50 - -""" - -import argparse -import logging -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -import whisper -from asr_datamodule import AishellAsrDataModule -from tn.chinese.normalizer import Normalizer -from whisper.normalizers import BasicTextNormalizer -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from zhconv import convert - -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def average_checkpoints( - filenames: List[Path], device: torch.device = torch.device("cpu") -) -> dict: - """Average a list of checkpoints. - The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. - - Args: - filenames: - Filenames of the checkpoints to be averaged. We assume all - checkpoints are saved by :func:`save_checkpoint`. - device: - Move checkpoints to this device before averaging. - Returns: - Return a dict (i.e., state_dict) which is the average of all - model state dicts contained in the checkpoints. - """ - n = len(filenames) - - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] - else: - avg = torch.load(filenames[0], map_location=device) - - # Identify shared parameters. Two parameters are said to be shared - # if they have the same data_ptr - uniqued: Dict[int, str] = dict() - - for k, v in avg.items(): - v_data_ptr = v.data_ptr() - if v_data_ptr in uniqued: - continue - uniqued[v_data_ptr] = k - - uniqued_names = list(uniqued.values()) - - for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] - else: - state_dict = torch.load(filenames[i], map_location=device) - for k in uniqued_names: - avg[k] += state_dict[k] - - for k in uniqued_names: - if avg[k].is_floating_point(): - avg[k] /= n - else: - avg[k] //= n - - return avg - - -def remove_punctuation(text: str or List[str]): - """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py - - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings without any punctuation. 
- """ - punctuation = "!,.;:?、!,。;:?《》 " - if isinstance(text, str): - text = re.sub(r"[{}]+".format(punctuation), "", text).strip() - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = re.sub(r"[{}]+".format(punctuation), "", t).strip() - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type {type(text)}") - - -def to_simple(text: str or List[str]): - """Convert traditional Chinese to simplified Chinese. - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings converted to simplified Chinese. - """ - if isinstance(text, str): - text = convert(text, "zh-cn") - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = convert(t, "zh-cn") - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type{type(text)}") - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=-1, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="beam-search", - help="""Decoding method. - Supported values are: - - beam-search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=1, - help="beam size for beam search decoding", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--remove-whisper-encoder-input-length-restriction", - type=str2bool, - default=True, - help="replace whisper encoder forward method to remove input length restriction", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: "beam-search" - - value: A list of lists. Each sublist is a list of token IDs. - Args: - params: - It is returned by :func:`get_params`. - model: - The neural model. - batch: - It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. - Returns: - Return a dict, whose key may be "beam-search". 
- """ - dtype = torch.float16 - device = torch.device("cuda") - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device, dtype=dtype).transpose(1, 2) - if not params.remove_whisper_encoder_input_length_restriction: - T = 3000 - if feature.shape[2] < T: - feature = torch.cat( - [ - feature, - torch.zeros( - feature.shape[0], feature.shape[1], T - feature.shape[2] - ).to(device, dtype=dtype), - ], - 2, - ) - - supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) - results = model.decode(feature, params.decoding_options) - hyps = [result.text for result in results] - - hyps = remove_punctuation(hyps) - hyps = to_simple(hyps) - hyps = [params.normalizer.normalize(hyp) for hyp in hyps] - - return {"beam-search": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - The dataloader. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a dict, whose key may be "beam-search". - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - batch=batch, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results_char, - enable_log=enable_log, - compute_CER=True, - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" - ) - - options = whisper.DecodingOptions( - task="transcribe", - language="zh", - without_timestamps=True, - beam_size=params.beam_size, - ) - params.decoding_options = options - params.cleaner = BasicTextNormalizer() - params.normalizer = Normalizer() - - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda") - - logging.info(f"device: {device}") - - if params.remove_whisper_encoder_input_length_restriction: - replace_whisper_encoder_forward() - model = whisper.load_model(params.model_name, "cpu") - if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - # deepspeed converted checkpoint only contains model state_dict - filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" - for epoch in range(start, params.epoch + 1) - ] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) - else: - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - model.load_state_dict(checkpoint, strict=True) - else: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell = AishellAsrDataModule(args) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - test_dl = aishell.test_dataloaders(aishell.test_cuts()) - test_sets = ["valid", "test"] - test_dls = [valid_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json deleted file mode 100644 index bf8cc0452..000000000 --- a/egs/aishell/ASR/whisper/ds_config_zero1.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 100, - "initial_scale_power": 16, - "hysteresis": 2, - "min_loss_scale": 0.01 - }, - "zero_optimization": { - "stage": 1, - "allgather_partitions": true, - "allgather_bucket_size": 2e8, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 2e8, - "contiguous_gradients": true - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-5 - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 1e-5, - "warmup_num_steps": 100 - } - }, - "gradient_accumulation_steps": 1, - "gradient_clipping": 5, - "steps_per_print": 50, - "train_micro_batch_size_per_gpu": 1, - "wall_clock_breakdown": false -} diff --git a/egs/aishell/ASR/whisper/label_smoothing.py b/egs/aishell/ASR/whisper/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/aishell/ASR/whisper/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/aishell/ASR/whisper/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt deleted file mode 100755 index 0708f2344..000000000 --- a/egs/aishell/ASR/whisper/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -k2 -kaldialign -git+https://github.com/lhotse-speech/lhotse -sentencepiece -tensorboard -librosa -git+https://github.com/yuekaizhang/whisper.git -zhconv -WeTextProcessing -deepspeed diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py deleted file mode 100755 index d77f8c270..000000000 --- a/egs/aishell/ASR/whisper/train.py +++ /dev/null @@ -1,927 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -#fine-tuning with deepspeed zero stage 1 -torchrun --nproc_per_node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --manifest-dir data/fbank_whisper \ - --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json - -# fine-tuning with ddp -torchrun --nproc_per_node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_medium \ - --manifest-dir data/fbank_whisper \ - --base-lr 1e-5 \ - --model-name medium -""" - - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import deepspeed -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import whisper -from asr_datamodule import AishellAsrDataModule -from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict -from label_smoothing import LabelSmoothingLoss -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.functional import pad as pad_tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import update_averaged_model -from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=10, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=1e-5, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser = deepspeed.add_config_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - frame_shift_ms: The frame shift in milliseconds. - - allowed_excess_duration_ratio: The allowed excess duration ratio. - - best_train_loss: The best training loss so far. - - best_valid_loss: The best validation loss so far. - - best_train_epoch: The epoch where the best training loss is achieved. - - best_valid_epoch: The epoch where the best validation loss is achieved. - - batch_idx_train: The batch index of the current batch. - - log_interval: Log training stats every `log_interval` batches. - - reset_interval: Reset the stats every `reset_interval` batches. - - valid_interval: Run validation every `valid_interval` batches. - - env_info: The environment information. 
- """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "subsampling_factor": 2, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 5000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute the loss for the given batch. - Args: - params: - It is returned by :func:`get_params`. - tokenizer: - The tokenizer used to encode the text. - model: - The model for training. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - Whether it is training. - Returns: - Return a tuple of two elements. The first element is the loss tensor. - """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - - def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: - padding_size = max(tensor.shape[0] for tensor in tensors) - dims = len(tensors[0].shape) - padded_tensors = [] - for tensor in tensors: - padding = [0] * 2 * dims - padding[-1] = padding_size - tensor.shape[0] - padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) - return torch.stack([tensor for tensor in padded_tensors], dim=0) - - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - - assert feature.ndim == 3 - feature = feature.to(device) - feature = feature.transpose(1, 2) # (N, C, T) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - - texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [text.replace(" ", "") for text in texts] - - text_tokens_list = [ - list(tokenizer.sot_sequence_including_notimestamps) - + tokenizer.encode(text) - + [tokenizer.eot] - for text in texts - ] - # convert it to torch tensor - text_tokens_list = [ - torch.LongTensor(text_tokens) for text_tokens in text_tokens_list - ] - - # 50256 is the index of for all whisper models - prev_outputs_tokens = _batch_tensors( - [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 - ) - target_tokens = _batch_tensors( - [tokens[1:] for tokens in text_tokens_list], pad_value=50256 - ) - target_lengths = torch.LongTensor( - [tokens.shape[0] - 1 for 
tokens in text_tokens_list] - ) - - decoder_criterion = LabelSmoothingLoss( - ignore_index=50256, label_smoothing=0.1, reduction="sum" - ) - - # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> - ignore_prefix_size = 3 - with torch.set_grad_enabled(is_training): - encoder_out = model.encoder(feature) - text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) - text_logits = text_logits[:, ignore_prefix_size:, :] - target_tokens = target_tokens[:, ignore_prefix_size:] - loss = decoder_criterion(text_logits, target_tokens.to(device)) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - tokenizer=tokenizer, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if params.deepspeed: - # deepspeed's backward() is different from torch's backward() - # in that it does not accept a loss tensor as input. - # It computes the loss internally. - model.backward(loss) - model.step() - else: - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - and not params.deepspeed - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - if batch_idx % params.log_interval == 0: - try: - cur_lr = scheduler.get_last_lr()[0] - except: # noqa - cur_lr = 0.0 - cur_grad_scale = ( - scaler._scale.item() - if (params.use_fp16 and not params.deepspeed) - else 1.0 - ) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if (params.use_fp16 and not params.deepspeed) - else "" - ) - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info(params) - - logging.info("About to create model") - - replace_whisper_encoder_forward() - model = whisper.load_model(params.model_name, "cpu") - del model.alignment_heads - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - tokenizer = whisper.tokenizer.get_tokenizer( - model.is_multilingual, - num_languages=model.num_languages, - language="zh", - task="transcribe", - ) - - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - else: - device = torch.device("cpu") - logging.info(f"Device: {device}") - model.to(device) - - optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if world_size > 1: - if params.deepspeed: - logging.info("Using DeepSpeed") - model, optimizer, _, scheduler = deepspeed.initialize( - args=params, model=model, 
model_parameters=model.parameters() - ) - else: - logging.info("Using DDP") - setup_dist(use_ddp_launch=True) - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - if not params.deepspeed: - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - tokenizer=tokenizer, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if params.deepspeed: - model.save_checkpoint( - save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", - client_state={}, - ) - if rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", - ) - else: - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1 and not params.deepspeed: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = get_world_size() - rank = get_rank() - - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - run(rank=rank, world_size=world_size, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py deleted file mode 100644 index 5bfbdce3b..000000000 --- a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import torch.nn.functional as F -import whisper - - -def forward(self, x: torch.Tensor): - """ - x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) - the mel spectrogram of the audio - """ - x = F.gelu(self.conv1(x)) - x = F.gelu(self.conv2(x)) - x = x.permute(0, 2, 1) - - x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype) - - for block in self.blocks: - x = block(x) - - x = self.ln_post(x) - return x - - -def replace_whisper_encoder_forward(): - """ - This function monkey patches the forward method of the whisper encoder. - To be called before the model is loaded, it changes whisper to process audio with any length < 30s. - """ - whisper.model.AudioEncoder.forward = forward diff --git a/egs/aishell/ASR/zipformer/__init__.py b/egs/aishell/ASR/zipformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell/ASR/zipformer/asr_datamodule.py b/egs/aishell/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/aishell/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/beam_search.py b/egs/aishell/ASR/zipformer/beam_search.py deleted file mode 120000 index 8554e44cc..000000000 --- a/egs/aishell/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py deleted file mode 100755 index 538189e52..000000000 --- a/egs/aishell/ASR/zipformer/decode.py +++ /dev/null @@ -1,818 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. 
The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, 
- ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level=True) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - compute_CER=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - aishell = AishellAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - dev_cuts = aishell.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) - - test_cuts = aishell.test_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py deleted file mode 100755 index 1ec10b059..000000000 --- a/egs/aishell/ASR/zipformer/decode_bbpe.py +++ /dev/null @@ -1,840 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Mingshuang Luo, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode_bbpe.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp_bbpe \ - --lang-dir data/lang_bbpe_500 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode_bbpe.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp_bbpe \ - --lang-dir data/lang_bbpe_500 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode_bbpe.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp_bbpe \ - --lang-dir data/lang_bbpe_500 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) -./zipformer/decode_bbpe.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp_bbpe \ - --lang-dir data/lang_bbpe_500 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode_bbpe.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp_bbpe \ - --lang-dir data/lang_bbpe_500 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer_bbpe/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the byte BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bbpe_500/", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). 
- """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - hyps.append([lexicon.word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - ref_texts = [] - for tx in supervisions["text"]: - ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) - - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(ref_texts), - nbest_scale=params.nbest_scale, - 
blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - lexicon: - directory containing the lexicon. - sp: - SentencePiece model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = "".join(ref_text.split()) - - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - lexicon = Lexicon(params.lang_dir) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - aishell = AishellAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - dev_cuts = aishell.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) - - test_cuts = aishell.test_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/zipformer/decode_stream.py b/egs/aishell/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/aishell/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decoder.py b/egs/aishell/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/aishell/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/encoder_interface.py b/egs/aishell/ASR/zipformer/encoder_interface.py deleted file mode 120000 index b9aa0ae08..000000000 --- a/egs/aishell/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx-streaming.py b/egs/aishell/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/aishell/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx.py b/egs/aishell/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/aishell/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff 
--git a/egs/aishell/ASR/zipformer/export.py b/egs/aishell/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/aishell/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained.py b/egs/aishell/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/aishell/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py deleted file mode 100755 index cd16284f7..000000000 --- a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer_bbpe/exp \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained.py \ - --nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall import smart_byte_decode - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - required=True, - help="""Path to the bbpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = model.decoder.blank_id - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - 
padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - features=features, - feature_lengths=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = smart_byte_decode(sp.decode(hyp)) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py b/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/joiner.py b/egs/aishell/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/aishell/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/model.py b/egs/aishell/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/aishell/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_check.py b/egs/aishell/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/aishell/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_decode.py b/egs/aishell/ASR/zipformer/onnx_decode.py deleted file mode 100755 index 17c6eceb4..000000000 --- a/egs/aishell/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. 
-""" - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from lhotse.cut import Cut -from onnx_pretrained import OnnxModel, greedy_search - -from icefall.utils import setup_logger, store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: k2.SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - Mapping ids to tokens. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. - """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - hyps = [[token_table[h] for h in hyp] for hyp in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: k2.SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - Mapping ids to tokens. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." - res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(args.tokens) - assert token_table[0] == "<blk>" - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - - aishell = AishellAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
- ) - return T > 0 - - dev_cuts = aishell.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) - - test_cuts = aishell.test_net_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py b/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained.py b/egs/aishell/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/aishell/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/optim.py b/egs/aishell/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/aishell/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/pretrained.py b/egs/aishell/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/aishell/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py deleted file mode 100755 index 387bef98a..000000000 --- a/egs/aishell/ASR/zipformer/pretrained_bbpe.py +++ /dev/null @@ -1,403 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. 
- -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp_bbpe \ - --tokens ./data/lang_bbpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp_bbpe \ - --causal 1 \ - --tokens ./data/lang_bbpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained_bbpe.py \ - --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe ./data/lang_bbpe_500/bbpe.model \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp_bbpe/epoch-xx.pt`. - -Note: ./zipformer/exp_bbpe/pretrained.pt is generated by ./zipformer/export_bbpe.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall import smart_byte_decode - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - required=True, - help="""Path to the bbpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/zipformer/scaling.py b/egs/aishell/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/aishell/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/zipformer/scaling_converter.py b/egs/aishell/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/aishell/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_beam_search.py b/egs/aishell/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/aishell/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 6a7ef2750..000000000 --- a/egs/aishell/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,881 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -from asr_datamodule import AishellAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="Path to the lang dir(containing lexicon, tokens, etc.)", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
- """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). 
- state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - encoder_out=encoder_out, - streams=decode_streams, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - lexicon: - The Lexicon. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
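As a quick sanity check on the tail-length arithmetic in decode_one_chunk above, here is the calculation spelled out for the default streaming configuration (purely illustrative, not part of the recipe):

# Illustrative only: the padding arithmetic used in decode_one_chunk.
chunk_size = 16                      # decoding chunk size at the 50 Hz encoder frame rate
frames_per_chunk = chunk_size * 2    # fbank features are at 100 Hz, so 32 feature frames per chunk
# 7 extra frames for the (T - 7) // 2 subsampling, plus 2 * 3 frames of right
# padding needed by the ConvNeXt module after subsampling.
tail_length = frames_per_chunk + 7 + 2 * 3
assert tail_length == 45
# Feature chunks shorter than 45 frames are padded with LOG_EPS up to this length.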
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - list(decode_streams[i].ground_truth.strip()), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - key = f"greedy_search_{key}" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_{key}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}_{key}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - aishell = AishellAsrDataModule(args) - - dev_cuts = aishell.valid_cuts() - test_cuts = aishell.test_cuts() - - test_sets = ["dev", "test"] - test_cuts = [dev_cuts, test_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/zipformer/subsampling.py b/egs/aishell/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/aishell/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py deleted file mode 100755 index dddfe52fa..000000000 --- a/egs/aishell/ASR/zipformer/train.py +++ /dev/null @@ -1,1349 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
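To make the checkpoint selection in the streaming decoding script above concrete: with the defaults --epoch 28 --avg 15, the two averaging modes touch the following files (a small illustration of the logic shown above; paths are relative to --exp-dir):

# Illustration of the checkpoint-selection logic above for --epoch 28 --avg 15.
epoch, avg = 28, 15

# --use-averaged-model false: plain average of the last `avg` epoch checkpoints.
start = epoch - avg + 1                                              # 14
plain_average = [f"epoch-{i}.pt" for i in range(start, epoch + 1)]   # epoch-14.pt .. epoch-28.pt

# --use-averaged-model true: only the two endpoint checkpoints are loaded, and the
# model averaged over the range (epoch - avg, epoch] is reconstructed from them.
filename_start = f"epoch-{epoch - avg}.pt"                           # epoch-13.pt (excluded endpoint)
filename_end = f"epoch-{epoch}.pt"                                   # epoch-28.pt

print(plain_average, filename_start, filename_end)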
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --exp-dir zipformer/exp \ - --training-subset L - --lr-epochs 1.5 \ - --max-duration 350 - -# For mix precision training: - -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --training-subset L \ - --lr-epochs 1.5 \ - --max-duration 750 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="""Maximum left-contexts for causal training, measured in frames which will - be converted to a number of chunks. 
If splitting into chunks, - chunk left-context frames will be chosen randomly from this list; else not relevant.""", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). 
We will scale the simple loss - with this parameter before adding to the final loss.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. 
- scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - train_cuts = aishell.train_cuts() - valid_cuts = aishell.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds - # - # Caution: There is a reason to select 15.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 12.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_dl = aishell.valid_dataloaders(valid_cuts) - - if False and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py deleted file mode 100755 index dbc262c5c..000000000 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ /dev/null @@ -1,941 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
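The `compute_loss()` implementations in the deleted `train.py` above (and again in `train_bbpe.py` below) blend the simple and pruned transducer losses with a linear warm-up over the first `warm_step` batches. A minimal standalone sketch of that schedule follows; the helper name `loss_warmup_weights` and the example numbers are illustrative only and do not appear in the recipe.

```python
def loss_warmup_weights(
    batch_idx_train: int, warm_step: int, simple_loss_scale: float
):
    """Weights applied to (simple_loss, pruned_loss) in compute_loss().

    The simple (regularization) loss starts at weight 1.0 and decays linearly
    to `simple_loss_scale`; the pruned loss ramps up from 0.1 to 1.0.
    """
    if batch_idx_train >= warm_step:
        return simple_loss_scale, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - simple_loss_scale), 0.1 + 0.9 * frac


# For example, with warm_step=2000 and simple_loss_scale=0.5
# (example values, matching the --simple-loss-scale default below):
#   batch 0    -> (1.00, 0.10)   lean on the simple loss while the pruning
#   batch 1000 -> (0.75, 0.55)   ranges are still unreliable
#   batch 2000 -> (0.50, 1.00)   fully warmed up
```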
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./zipformer/train_bbpe.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --exp-dir zipformer/exp_bbpe \ - --max-duration 350 - -# For mix precision training: - -./zipformer/train_bbpe.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp_bbpe \ - --max-duration 750 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from typing import Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from train import ( - LRSchedulerType, - add_model_arguments, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - set_batch_count, -) - -from icefall import byte_encode, diagnostics -from icefall.checkpoint import remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, - tokenize_by_CJK_char, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer_bbpe/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the Byte BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. 
We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). We will scale the simple loss - with this parameter before adding to the final loss.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. 
- world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - aishell = AishellAsrDataModule(args) - - train_cuts = aishell.train_cuts() - valid_cuts = aishell.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds - # - # Caution: There is a reason to select 15.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 15.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - def tokenize_and_encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = byte_encode(tokenize_by_CJK_char(text)) - c.supervisions[0].text = text - return c - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_cuts = train_cuts.map(tokenize_and_encode_text) - - valid_cuts = valid_cuts.map(tokenize_and_encode_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_dl = aishell.valid_dataloaders(valid_cuts) - - if False and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The sentence piece model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/aishell/ASR/zipformer/zipformer.py b/egs/aishell/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/aishell/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/aishell2/ASR/README.md b/egs/aishell2/ASR/README.md deleted file mode 100644 index 4e786af11..000000000 --- a/egs/aishell2/ASR/README.md +++ /dev/null @@ -1,23 +0,0 @@ - -# Introduction - -This recipe contains various different ASR models trained with Aishell2. - -In AISHELL-2, 1000 hours of clean read-speech data from iOS is published, which is free for academic usage. On top of AISHELL-2 corpus, an improved recipe is developed and released, containing key components for industrial applications, such as Chinese word segmentation, flexible vocabulary expension and phone set transformation etc. Pipelines support various state-of-the-art techniques, such as time-delayed neural networks and Lattic-Free MMI objective funciton. In addition, we also release dev and test data from other channels (Android and Mic). - -(From [AISHELL-2: Transforming Mandarin ASR Research Into Industrial Scale](https://arxiv.org/abs/1808.10583)) - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless5 in librispeech recipe | - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. 
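A rough sketch of that decoder layout (an embedding followed by a Conv1d over the last few predicted symbols, i.e. no recurrent state); the module below is a simplified illustration and does not reproduce the exact padding, grouping, or hyper-parameters of the recipe's `decoder.py`:

```python
import torch
import torch.nn as nn


class StatelessDecoderSketch(nn.Module):
    """Embedding + Conv1d prediction network, as described above (simplified)."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int, blank_id: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=blank_id)
        # kernel_size == context_size, so each output position only sees the
        # previous `context_size` symbols -- hence a "stateless" predictor.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids; left-pad so the output length stays U.
        embed = self.embedding(y).permute(0, 2, 1)              # (N, C, U)
        embed = nn.functional.pad(embed, (self.context_size - 1, 0))
        return torch.relu(self.conv(embed)).permute(0, 2, 1)    # (N, U, C)
```

With a context size of 2 the predictor behaves like a tri-gram model over previously emitted symbols (compare the `--context-size` help text in the training scripts above: "1 means bigram; 2 means tri-gram").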
diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md deleted file mode 100644 index 0b7ae9299..000000000 --- a/egs/aishell2/ASR/RESULTS.md +++ /dev/null @@ -1,87 +0,0 @@ -## Results - -### Aishell2 char-based training results - -#### Pruned transducer stateless 5 - -Using the codes from this commit https://github.com/k2-fsa/icefall/pull/465. - -When training with context size equals to 1, the WERs are - -| | dev-ios | test-ios | comment | -|------------------------------------|-------|----------|----------------------------------| -| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 | -| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search (set as default) | 5.5 | 5.78 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search oracle | 1.92 | 2.2 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 | - -The training command for reproducing is given below: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --lang-dir data/lang_char \ - --num-epochs 40 \ - --start-epoch 1 \ - --exp-dir /result \ - --max-duration 300 \ - --use-fp16 0 \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --context-size 1 -``` - -The decoding command is: -```bash -for method in greedy_search modified_beam_search fast_beam_search fast_beam_search_nbest fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do - ./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method $method \ - --max-sym-per-frame 1 \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --context-size 1 \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 \ - --context-size 1 \ - --use-averaged-model True -done -``` -The tensorboard training log can be found at -https://tensorboard.dev/experiment/RXyX4QjQQVKjBS2eQ2Qajg/#scalars - -A pre-trained model and decoding logs can be found at - -When training with context size equals to 2, the WERs are - -| | dev-ios | test-ios | comment | -|------------------------------------|-------|----------|----------------------------------| -| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 | -| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search nbest | 5.37 | 5.6 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search oracle | 2.04 | 2.2 | --epoch 25, --avg 5, --max-duration 600 | -| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 | - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars - -A pre-trained model and decoding logs can be found at diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell2/ASR/local/compile_lg.py b/egs/aishell2/ASR/local/compile_lg.py deleted file mode 120000 index 
462d6d3fb..000000000 --- a/egs/aishell2/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py deleted file mode 100755 index 557f22b0c..000000000 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aishell2 dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_aishell2( - num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(8, os.cpu_count()) - - dataset_parts = ( - "train", - "dev", - "test", - ) - prefix = "aishell2" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info("Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_aishell2( - num_mel_bins=args.num_mel_bins, - perturb_speed=args.perturb_speed, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/aishell2/ASR/local/compute_fbank_musan.py b/egs/aishell2/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/aishell2/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/display_manifest_statistics.py b/egs/aishell2/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 14844cbf3..000000000 --- a/egs/aishell2/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer_stateless/train.py -for usage. 
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/fbank/aishell2_cuts_train.jsonl.gz", - "./data/fbank/aishell2_cuts_dev.jsonl.gz", - "./data/fbank/aishell2_cuts_test.jsonl.gz", - ] - - for path in paths: - print(f"Starting display the statistics for {path}") - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Starting display the statistics for ./data/fbank/aishell2_cuts_train.jsonl.gz -Cuts count: 3026106 -Total duration (hours): 3021.2 -Speech duration (hours): 3021.2 (100.0%) -*** -Duration statistics (seconds): -mean 3.6 -std 1.5 -min 0.3 -25% 2.4 -50% 3.3 -75% 4.4 -99% 8.2 -99.5% 8.9 -99.9% 10.6 -max 21.5 -Starting display the statistics for ./data/fbank/aishell2_cuts_dev.jsonl.gz -Cuts count: 2500 -Total duration (hours): 2.0 -Speech duration (hours): 2.0 (100.0%) -*** -Duration statistics (seconds): -mean 2.9 -std 1.0 -min 1.1 -25% 2.2 -50% 2.7 -75% 3.4 -99% 6.3 -99.5% 6.7 -99.9% 7.8 -max 9.4 -Starting display the statistics for ./data/fbank/aishell2_cuts_test.jsonl.gz -Cuts count: 5000 -Total duration (hours): 4.0 -Speech duration (hours): 4.0 (100.0%) -*** -Duration statistics (seconds): -mean 2.9 -std 1.0 -min 1.1 -25% 2.2 -50% 2.7 -75% 3.3 -99% 6.2 -99.5% 6.6 -99.9% 7.7 -max 8.5 -""" diff --git a/egs/aishell2/ASR/local/prepare_char.py b/egs/aishell2/ASR/local/prepare_char.py deleted file mode 120000 index 8779181e5..000000000 --- a/egs/aishell2/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../aidatatang_200zh/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/prepare_lang.py b/egs/aishell2/ASR/local/prepare_lang.py deleted file mode 120000 index 5d88dc1c8..000000000 --- a/egs/aishell2/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/prepare_words.py b/egs/aishell2/ASR/local/prepare_words.py deleted file mode 120000 index e58fabb8f..000000000 --- a/egs/aishell2/ASR/local/prepare_words.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/text2segments.py b/egs/aishell2/ASR/local/text2segments.py deleted file mode 120000 index 7d68a39c3..000000000 --- a/egs/aishell2/ASR/local/text2segments.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/text2token.py b/egs/aishell2/ASR/local/text2token.py deleted file mode 120000 index 81e459d69..000000000 --- a/egs/aishell2/ASR/local/text2token.py +++ /dev/null @@ -1 +0,0 @@ -../../../aidatatang_200zh/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh deleted file mode 100755 index c959bd4d1..000000000 --- a/egs/aishell2/ASR/prepare.sh +++ /dev/null @@ -1,196 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=30 -stage=0 -stop_stage=7 -perturb_speed=true - - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, you need to apply aishell2 through -# their official website. 
-# https://www.aishelltech.com/aishell_2 -# -# - $dl_dir/aishell2 -# -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: Download data" - - # If you have pre-downloaded it to /path/to/aishell2, - # you can create a symlink - # - # ln -sfv /path/to/aishell2 $dl_dir/aishell2 - # - # The directory structure is - # aishell2/ - # |-- AISHELL-2 - # | |-- iOS - # |-- data - # |-- wav - # |-- trans.txt - # |-- dev - # |-- wav - # |-- trans.txt - # |-- test - # |-- wav - # |-- trans.txt - - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare aishell2 manifest" - # We assume that you have downloaded and unzip the aishell2 corpus - # to $dl_dir/aishell2 - if [ ! -f data/manifests/.aishell2_manifests.done ]; then - mkdir -p data/manifests - lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj - touch data/manifests/.aishell2_manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for aishell2" - if [ ! -f data/fbank/.aishell2.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} - touch data/fbank/.aishell2.done - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then - log "Stage 30: Compute whisper fbank for aishell2" - if [ ! -f data/fbank/.aishell2.whisper.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.aishell2.whisper.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -lang_char_dir=data/lang_char -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - mkdir -p $lang_char_dir - - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. cp jq /usr/bin - if [ ! 
-f $lang_char_dir/text ]; then - gunzip -c data/manifests/aishell2_supervisions_train.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - fi - - # The implementation of chinese word segmentation for text, - # and it will take about 15 minutes. - # If you can't install paddle-tiny with python 3.8, please refer to - # https://github.com/fxsjy/jieba/issues/920 - if [ ! -f $lang_char_dir/text_words_segmentation ]; then - python3 ./local/text2segments.py \ - --input-file $lang_char_dir/text \ - --output-file $lang_char_dir/text_words_segmentation - fi - - cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt - - if [ ! -f $lang_char_dir/words.txt ]; then - python3 ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - python3 ./local/prepare_char.py - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lang_char_dir/text_words_segmentation \ - -lm $lang_char_dir/3-gram.unpruned.arpa - fi - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building LG - python3 -m kaldilm \ - --read-symbol-table="$lang_char_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compile LG" - ./local/compile_lg.py --lang-dir $lang_char_dir -fi diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 100644 index f9cdfb621..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
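Editor's note: for readers without jq, the transcript-extraction step of Stage 5 above (gunzip | jq '.text' | sed) can be sketched in plain Python. This is only an illustration using the same paths as prepare.sh; the subsequent ./local/text2token.py -t "char" step is still required to obtain $lang_char_dir/text in character form.

import gzip
import json

# Assumed path produced by Stage 1 of prepare.sh; each line is one
# supervision segment serialized as JSON with a "text" field.
manifest = "data/manifests/aishell2_supervisions_train.jsonl.gz"

with gzip.open(manifest, "rt", encoding="utf-8") as fin:
    for line in fin:
        print(json.loads(line)["text"])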
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AiShell2AsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. ios, android, mic). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to gen cuts from aishell2_cuts_train.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to gen cuts from aishell2_cuts_test.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_test.jsonl.gz" - ) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py b/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 120000 index c7c1a4b6e..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index 9e44b4e34..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,776 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi 
Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -(5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 25 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AiShell2AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import 
Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": 
- return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AiShell2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.unk_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = 
params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell2 = AiShell2AsrDataModule(args) - - valid_cuts = aishell2.valid_cuts() - test_cuts = aishell2.test_cuts() - - # use ios sets for dev and test - dev_dl = aishell2.valid_dataloaders(valid_cuts) - test_dl = aishell2.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index c92c7ab83..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,271 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
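Editor's note: export.py below reuses the same checkpoint-selection logic as decode.py above. A small worked illustration (mine, not part of either script) of which files --epoch 25 --avg 5 ends up reading, with exp/ standing in for --exp-dir:

epoch, avg = 25, 5

# --use-averaged-model false: epochs (epoch - avg + 1) .. epoch are averaged.
start = epoch - avg + 1                                       # 21
plain_average = [f"exp/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]
# -> exp/epoch-21.pt ... exp/epoch-25.pt

# --use-averaged-model true: per the help text, only two checkpoints are read,
# the excluded start of the range and the end of the range.
filename_start = f"exp/epoch-{epoch - avg}.pt"                # exp/epoch-20.pt
filename_end = f"exp/epoch-{epoch}.pt"                        # exp/epoch-25.pt

print(plain_average, filename_start, filename_end)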
-""" -Usage: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --tokens ./data/lang_char/tokens.txt \ - --epoch 25 \ - --avg 5 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell2/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py b/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/model.py b/egs/aishell2/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py b/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100755 index f04632388..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
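Editor's note: for orientation, a minimal sketch (paths assumed to match the export commands above) of how the two artifacts written by export.py are loaded back; pretrained.py below consumes the state-dict form through its --checkpoint argument.

import torch

# TorchScript export (--jit true): a self-contained scripted model.
scripted = torch.jit.load(
    "pruned_transducer_stateless5/exp/cpu_jit.pt", map_location="cpu"
)

# Default export: a checkpoint whose "model" entry is a plain state_dict,
# the format that pretrained.py loads via load_state_dict().
ckpt = torch.load(
    "pruned_transducer_stateless5/exp/pretrained.pt", map_location="cpu"
)
state_dict = ckpt["model"]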
-""" -Usage: - -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. 
- """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.unk_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - 
encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py b/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index 8c7448d4c..000000000 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# Copyright 2022 Nvidia (authors: Yuekai Zhang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
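Editor's note: pretrained.py above asserts a 16 kHz sample rate and keeps only the first channel of each input. A small hedged sketch (file names made up) of converting an arbitrary wav before passing it on the command line:

import torchaudio
import torchaudio.functional as F

wave, sr = torchaudio.load("/path/to/foo.wav")            # (channels, samples)
if sr != 16000:
    wave = F.resample(wave, orig_freq=sr, new_freq=16000)
# keep a single channel, matching what read_sound_files() does above
torchaudio.save("/path/to/foo_16k.wav", wave[:1], 16000)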
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --lang-dir data/lang_char \ - --num-epochs 40 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --max-duration 300 \ - --use-fp16 0 \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AiShell2AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
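The `--average-period` help above states the update rule for `model_avg` in words. Written out as code it looks roughly like the sketch below; the recipe itself calls `icefall.checkpoint.update_averaged_model`, so this is only an illustration of the formula, not the actual implementation.

```python
import torch


@torch.no_grad()
def average_update_sketch(model_avg, model, average_period: int, batch_idx_train: int):
    """model_avg <- w * model + (1 - w) * model_avg, with w = average_period / batch_idx_train."""
    w = average_period / batch_idx_train
    for p_avg, p in zip(model_avg.parameters(), model.parameters()):
        p_avg.mul_(1.0 - w).add_(p, alpha=w)
```

With the defaults above (average_period=100), at batch 10000 each update pulls `model_avg` one percent of the way toward the current weights, so the averaged model changes ever more slowly as training proceeds.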
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - assert type(y) == list - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
- pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
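To put numbers to the warm-up handling in compute_loss() above: `warmup` is `batch_idx_train / model_warm_step`, and with the default `model_warm_step` of 3000 from get_params() the pruned-loss weight follows the schedule below (the same expression as above, factored into a standalone helper for clarity).

```python
def pruned_loss_scale(batch_idx_train: int, model_warm_step: int = 3000) -> float:
    warmup = batch_idx_train / model_warm_step
    if warmup < 1.0:
        return 0.0   # e.g. batch 1500 (warmup 0.5): only the simple loss is trained
    elif 1.0 < warmup < 2.0:
        return 0.1   # e.g. batch 4500 (warmup 1.5): pruned loss enters with a small weight
    else:
        return 1.0   # from batch 6000 on: the full pruned RNN-T loss is used
```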
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell2 = AiShell2AsrDataModule(args) - - train_cuts = aishell2.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 8 seconds - # - # Caution: There is a reason to select 8.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 8.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = aishell2.valid_cuts() - valid_dl = aishell2.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - AiShell2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell2/ASR/shared b/egs/aishell2/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/aishell2/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md deleted file mode 100644 index b96161762..000000000 --- a/egs/aishell4/ASR/README.md +++ /dev/null @@ -1,23 +0,0 @@ - -# Introduction - -This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). - -The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. - -(From [Open Speech and Language Resources](https://www.openslr.org/111/)) - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). 
-We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/aishell4/ASR/RESULTS.md b/egs/aishell4/ASR/RESULTS.md deleted file mode 100644 index 9bd062f1d..000000000 --- a/egs/aishell4/ASR/RESULTS.md +++ /dev/null @@ -1,117 +0,0 @@ -## Results - -### Aishell4 Char training results (Pruned Transducer Stateless5) - -#### 2022-06-13 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/399. - -When use-averaged-model=False, the CERs are -| | test | comment | -|------------------------------------|------------|------------------------------------------| -| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 | -| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 | -| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500| - -When use-averaged-model=True, the CERs are -| | test | comment | -|------------------------------------|------------|----------------------------------------------------------------------| -| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True | -| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True | -| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 220 \ - --save-every-n 4000 - -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars - -When use-averaged-model=False, the decoding command is: -``` -epoch=30 -avg=25 - -## greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 - -## modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -When use-averaged-model=True, the decoding command is: -``` -iter=36000 -avg=8 - -## greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 \ - --use-averaged-model True - -## modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --use-averaged-model True - -## fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - 
--max-contexts 4 \ - --max-states 8 \ - --use-averaged-model True -``` - -A pre-trained model and decoding logs can be found at diff --git a/egs/aishell4/ASR/local/__init__.py b/egs/aishell4/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py deleted file mode 100755 index b5f8468ac..000000000 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aidatatang_200zh dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_aishell4( - num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/aishell4") - output_dir = Path("data/fbank") - num_jobs = min(8, os.cpu_count()) - - dataset_parts = ( - "train_S", - "train_M", - "train_L", - "test", - ) - prefix = "aishell4" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info("Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - logging.info("About splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - min_duration=None, - ) - - cut_set.to_file(output_dir / cuts_filename) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_aishell4( - num_mel_bins=args.num_mel_bins, - perturb_speed=args.perturb_speed, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/aishell4/ASR/local/compute_fbank_musan.py b/egs/aishell4/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/aishell4/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/aishell4/ASR/local/display_manifest_statistics.py b/egs/aishell4/ASR/local/display_manifest_statistics.py deleted file mode 100644 index b79e55eef..000000000 --- a/egs/aishell4/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. 
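A short, hypothetical usage example of the workflow this docstring describes: inspect the duration statistics, then pick thresholds for the duration filter, mirroring remove_short_and_long_utt() in the train.py deleted earlier in this diff (which keeps 1.0 to 8.0 seconds). The manifest path is one of those listed in main() below.

```python
from lhotse import load_manifest

cuts = load_manifest("./data/fbank/cuts_train_S.json.gz")
cuts.describe()  # prints the cut count, total duration and duration percentiles

# Keep utterances between 1 and 8 seconds, as in remove_short_and_long_utt():
kept = cuts.filter(lambda c: 1.0 <= c.duration <= 8.0)
```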
-""" - - -from lhotse import load_manifest - - -def main(): - paths = [ - "./data/fbank/cuts_train_S.json.gz", - "./data/fbank/cuts_train_M.json.gz", - "./data/fbank/cuts_train_L.json.gz", - "./data/fbank/cuts_test.json.gz", - ] - - for path in paths: - print(f"Starting display the statistics for {path}") - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Starting display the statistics for ./data/fbank/cuts_train_S.json.gz -Cuts count: 91995 -Total duration (hours): 95.8 -Speech duration (hours): 95.8 (100.0%) -*** -Duration statistics (seconds): -mean 3.7 -std 7.1 -min 0.1 -25% 0.9 -50% 2.5 -75% 5.4 -99% 15.3 -99.5% 17.5 -99.9% 23.3 -max 1021.7 -Starting display the statistics for ./data/fbank/cuts_train_M.json.gz -Cuts count: 177195 -Total duration (hours): 179.5 -Speech duration (hours): 179.5 (100.0%) -*** -Duration statistics (seconds): -mean 3.6 -std 6.4 -min 0.0 -25% 0.9 -50% 2.4 -75% 5.2 -99% 14.9 -99.5% 17.0 -99.9% 23.5 -max 990.4 -Starting display the statistics for ./data/fbank/cuts_train_L.json.gz -Cuts count: 37572 -Total duration (hours): 49.1 -Speech duration (hours): 49.1 (100.0%) -*** -Duration statistics (seconds): -mean 4.7 -std 4.0 -min 0.2 -25% 1.6 -50% 3.7 -75% 6.7 -99% 17.5 -99.5% 19.8 -99.9% 26.2 -max 87.4 -Starting display the statistics for ./data/fbank/cuts_test.json.gz -Cuts count: 10574 -Total duration (hours): 12.1 -Speech duration (hours): 12.1 (100.0%) -*** -Duration statistics (seconds): -mean 4.1 -std 3.4 -min 0.2 -25% 1.4 -50% 3.2 -75% 5.8 -99% 14.4 -99.5% 14.9 -99.9% 16.5 -max 17.9 -""" diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py deleted file mode 100755 index 6b440dfb3..000000000 --- a/egs/aishell4/ASR/local/prepare_char.py +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/text, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. 
The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. - Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - chars = list(word.strip(" \t")) - if contain_oov(token_sym_table, chars): - continue - lexicon.append((word, chars)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. 
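A toy illustration of what generate_tokens() and generate_lexicon() above produce for a two-word character lexicon. Token IDs 0-2 are reserved for special symbols; the names used here follow comparable icefall char recipes and should be read as an assumption.

```python
# Hypothetical input: a text file containing "你好 世界" and a word list ["你好", "世界"].
token_sym_table = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2,   # reserved IDs, names assumed
                   "你": 3, "好": 4, "世": 5, "界": 6}        # one ID per character seen

words = ["你好", "世界"]
lexicon = [(w, list(w)) for w in words]   # character-level "pronunciations"
# -> [("你好", ["你", "好"]), ("世界", ["世", "界"])]
# generate_lexicon() additionally appends an entry mapping the OOV word to the unknown token.
```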
- """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([ \t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - chars = list(line) - for char in chars: - if char not in tokens: - tokens[char] = len(tokens) - return tokens - - -def main(): - lang_dir = Path("data/lang_char") - text_file = lang_dir / "text" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py deleted file mode 100755 index c8cf9b881..000000000 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. 
- - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. 
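A worked example of what add_disambig_symbols() above does, using a hypothetical three-entry lexicon in which one token sequence is duplicated and is also a prefix of a longer sequence:

```python
lexicon = [
    ("AB",  ["a", "b"]),
    ("AB2", ["a", "b"]),        # same token sequence as "AB"
    ("ABC", ["a", "b", "c"]),   # "a b" is also a prefix of this entry
]
# add_disambig_symbols(lexicon) returns:
#   [("AB",  ["a", "b", "#1"]),
#    ("AB2", ["a", "b", "#2"]),
#    ("ABC", ["a", "b", "c"])]   with max_disambig = 2
# "a b c" is unique and not a prefix of any other entry, so it needs no disambig symbol.
```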
They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") - return parser.parse_args() - - -def main(): - out_dir = Path(get_args().lang_dir) - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/prepare_words.py b/egs/aishell4/ASR/local/prepare_words.py deleted file mode 100755 index 65aca2983..000000000 --- a/egs/aishell4/ASR/local/prepare_words.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
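A hedged usage sketch for lexicon_to_fst() with a two-word toy lexicon. It assumes the deleted prepare_lang.py is importable and that the epsilon symbol whose angle-bracket tag was stripped from the listing above is "<eps>" with ID 0; "#0" is the disambiguation symbol required when need_self_loops=True:

    from prepare_lang import add_disambig_symbols, lexicon_to_fst

    lexicon = [("ab", ["a", "b"]), ("a", ["a"])]
    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)  # "a" -> "a #1"

    token2id = {"<eps>": 0, "SIL": 1, "a": 2, "b": 3, "#0": 4, "#1": 5}
    word2id = {"<eps>": 0, "ab": 1, "a": 2, "#0": 3}

    L = lexicon_to_fst(lexicon, token2id=token2id, word2id=word2id,
                       sil_token="SIL", sil_prob=0.5)
    L_disambig = lexicon_to_fst(lexicon_disambig, token2id=token2id,
                                word2id=word2id, sil_token="SIL",
                                sil_prob=0.5, need_self_loops=True)
    print(L)           # a k2.Fsa with optional silence at word boundaries
    print(L_disambig)  # same topology plus #0 self-loops and #1 disambig arcs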
- - -""" -This script takes as input words.txt without ids: - - words_no_ids.txt -and generates the new words.txt with related ids. - - words.txt -""" - - -import argparse -import logging - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare words.txt", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/words_no_ids.txt", - type=str, - help="the words file without ids for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/words.txt", - type=str, - help="the words file with ids for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - add_words = [" 0", "!SIL 1", " 2", " 3"] - new_lines.extend(add_words) - - logging.info("Starting reading the input file") - for i in tqdm(range(len(lines))): - x = lines[i] - idx = 4 + i - new_line = str(x.strip("\n")) + " " + str(idx) - new_lines.append(new_line) - - logging.info("Starting writing the words.txt") - f_out = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py deleted file mode 100755 index 74e025ad7..000000000 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
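A hedged, minimal sketch of what prepare_words.py produced: the special symbols first, then every word from words_no_ids.txt numbered consecutively. The paths are the script's defaults; the four special entries are assumed to be the usual icefall ones, since their angle-bracket tags were stripped from the listing above:

    # Equivalent of prepare_words.py without argparse/tqdm, for illustration only.
    special = ["<eps>", "!SIL", "<SPN>", "<UNK>"]  # assumed; tags stripped above

    with open("data/lang_char/words_no_ids.txt", encoding="utf-8") as fin, \
         open("data/lang_char/words.txt", "w", encoding="utf-8") as fout:
        words = special + [line.strip() for line in fin]
        for i, w in enumerate(words):
            fout.write(f"{w} {i}\n")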
- - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/text2segments.py b/egs/aishell4/ASR/local/text2segments.py deleted file mode 100644 index 3df727c67..000000000 --- a/egs/aishell4/ASR/local/text2segments.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input "text", which refers to the transcript file for -WenetSpeech: - - text -and generates the output file text_word_segmentation which is implemented -with word segmenting: - - text_words_segmentation -""" - - -import argparse - -import jieba -from tqdm import tqdm - -jieba.enable_paddle() - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Chinese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/text", - type=str, - help="the input text file for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with words segmenting for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - for i in tqdm(range(len(lines))): - x = lines[i].rstrip() - seg_list = jieba.cut(x, use_paddle=True) - new_line = " ".join(seg_list) - new_lines.append(new_line) - - f_new = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_new.write(line) - f_new.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py deleted file mode 100755 index 85047c367..000000000 --- a/egs/aishell4/ASR/local/text2token.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
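A hedged illustration of the word-segmentation step performed by text2segments.py. The paddle backend enabled above needs the paddlepaddle package; the default HMM-based segmenter shown here is enough to see the output format (one space-joined segmented line per input line):

    import jieba

    line = "今天天气真不错"
    print(" ".join(jieba.cut(line)))  # prints the space-joined word segmentation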
- - -import argparse -import codecs -import re -import sys -from typing import List - -from pypinyin import lazy_pinyin, pinyin - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symobles, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "pinyin", "lazy_pinyin"], - help="""Transcript type. char/pinyin/lazy_pinyin""", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the chinese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "pinyin" and "lazy_pinyin". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "lazy_pinyin": - text = lazy_pinyin(chars_list) - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - else: # token_type = "pinyin" - text = pinyin(chars_list) - sub_ids = [ - token_table[txt[0]] if txt[0] in token_table else oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "pinyin": - a = pinyin(list(str(a))) - a = [one[0] for one in a] - - if args.trans_type == "lazy_pinyin": - a = lazy_pinyin(list(str(a))) - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = "".join(a_flat) - print(a_chars) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/local/text_normalize.py b/egs/aishell4/ASR/local/text_normalize.py deleted file mode 100755 index 5650be502..000000000 --- a/egs/aishell4/ASR/local/text_normalize.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input "text_full", which includes three transcript files -(train_S, train_M and train_L) for AISHELL4: - - text_full -and generates the output file text_normalize which is implemented -to normalize text: - - text -""" - - -import argparse - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Normalizing for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input", - default="data/lang_char/text_full", - type=str, - help="the input text files for AISHELL4", - ) - parser.add_argument( - "--output", - default="data/lang_char/text", - type=str, - help="the text implemented with normalizer for AISHELL4", - ) - - return parser - - -def text_normalize(str_line: str): - line = str_line.strip().rstrip("\n") - line = line.replace(" ", "") - line = line.replace("", "") - line = line.replace("<%>", "") - line = line.replace("<->", "") - line = line.replace("<$>", "") - line = line.replace("<#>", "") - line = line.replace("<_>", "") - line = line.replace("", "") - line = line.replace("`", "") - line = line.replace("&", "") - line = line.replace(",", "") - line = line.replace("A", "") - line = line.replace("a", "A") - line = line.replace("b", "B") - line = line.replace("c", "C") - line = line.replace("k", "K") - line = line.replace("t", "T") - line = line.replace(",", "") - line = line.replace("丶", "") - line = line.replace("。", "") - line = line.replace("、", "") - line = line.replace("?", "") - line = line.replace("·", "") - line = line.replace("*", "") - line = line.replace("!", "") - line = line.replace("$", "") - line = line.replace("+", "") - line = line.replace("-", "") - line = line.replace("\\", "") - line = line.replace("?", "") - line = line.replace("¥", "") - line = line.replace("%", "") - line = line.replace(".", "") - line = line.replace("<", "") - line = line.replace("&", "") - line = line.upper() - - return line - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input - output_file = args.output - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - for i in tqdm(range(len(lines))): - new_line = text_normalize(lines[i]) - new_lines.append(new_line) - - f_new = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_new.write(line) - f_new.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh deleted file mode 100755 index 38a36d97a..000000000 --- a/egs/aishell4/ASR/prepare.sh +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=7 -perturb_speed=true - - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/aishell4 -# You can find four directories:train_S, train_M, train_L and test. -# You can download it from https://openslr.org/111/ -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
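A hedged usage sketch for the text_normalize() helper above, assuming text_normalize.py is importable from local/ and using a made-up input line. It strips spaces and the listed punctuation, maps a handful of lowercase Latin letters explicitly, and finally upper-cases everything:

    from text_normalize import text_normalize

    print(text_normalize("今天 开 apec 会议。真不错"))
    # roughly -> "今天开APEC会议真不错": spaces and full-width punctuation
    # removed, Latin letters upper-cased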
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/aishell4, - # you can create a symlink - # - # ln -sfv /path/to/aishell4 $dl_dir/aishell4 - # - if [ ! -f $dl_dir/aishell4/train_L ]; then - lhotse download aishell4 $dl_dir/aishell4 - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare aishell4 manifest" - # We assume that you have downloaded the aishell4 corpus - # to $dl_dir/aishell4 - if [ ! -f data/manifests/aishell4/.manifests.done ]; then - mkdir -p data/manifests/aishell4 - lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4 - touch data/manifests/aishell4/.manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute fbank for aishell4" - if [ ! -f data/fbank/aishell4/.fbank.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} - touch data/fbank/.fbank.done - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then - log "Stage 20: Compute whisper fbank for aishell4" - if [ ! -f data/fbank/aishell4/.fbank.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.fbank.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - lang_char_dir=data/lang_char - mkdir -p $lang_char_dir - - # Prepare text. 
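Stage 5 below extracts the "text" field from the supervisions manifests with jq. For readers without jq, a hedged Python equivalent of that extraction step (manifest path as in the script; lhotse supervisions are gzipped JSON lines with a top-level "text" field):

    import gzip
    import json

    path = "data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz"
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            # each printed line is what gets piped into ./local/text2token.py -t "char"
            print(json.loads(line)["text"])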
- # Note: in Linux, you can install jq with the following command: - # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_S - - gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_M - - gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_L - - for r in text_S text_M text_L ; do - cat $lang_char_dir/$r >> $lang_char_dir/text_full - done - - # Prepare text normalize - python ./local/text_normalize.py \ - --input $lang_char_dir/text_full \ - --output $lang_char_dir/text - - # Prepare words segments - python ./local/text2segments.py \ - --input $lang_char_dir/text \ - --output $lang_char_dir/text_words_segmentation - - cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \ - | sort -u | sed "/^$/d" \ - | uniq > $lang_char_dir/words_no_ids.txt - - # Prepare words.txt - if [ ! -f $lang_char_dir/words.txt ]; then - ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py - fi -fi diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell4/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 100644 index c10456da5..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class Aishell4AsrDataModule: - """ - DataModule for k2 ASR experiments. 
- It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - - group.add_argument( - "--num-buckets", - type=int, - default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. 
" - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - rank=0, - world_size=1, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - rank=0, - world_size=1, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_S_cuts(self) -> CutSet: - logging.info("About to get S train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz" - ) - - @lru_cache() - def train_M_cuts(self) -> CutSet: - logging.info("About to get M train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz" - ) - - @lru_cache() - def train_L_cuts(self) -> CutSet: - logging.info("About to get L train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz" - ) - - 
@lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - # Aishell4 doesn't have dev data, here use test to replace dev. - return load_manifest_lazy( - self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz" - ) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py b/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index ed78bd4bb..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py b/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 120000 index c7c1a4b6e..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index 068e2749a..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,623 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
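A hedged end-to-end sketch of using the deleted Aishell4AsrDataModule. It assumes the fbank and MUSAN manifests produced by prepare.sh exist under data/fbank; the CLI values are arbitrary:

    import argparse

    from asr_datamodule import Aishell4AsrDataModule

    parser = argparse.ArgumentParser()
    Aishell4AsrDataModule.add_arguments(parser)
    args = parser.parse_args(["--max-duration", "200", "--num-workers", "2"])

    aishell4 = Aishell4AsrDataModule(args)
    train_cuts = aishell4.train_L_cuts()               # lazily loads the cut manifest
    train_dl = aishell4.train_dataloaders(train_cuts)  # DynamicBucketingSampler by default
    valid_dl = aishell4.valid_dataloaders(aishell4.valid_cuts())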
-""" -When use-averaged-model=True, usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --iter 36000 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 800 \ - --decoding-method greedy_search \ - --use-averaged-model True - -(2) modified beam search -./pruned_transducer_stateless5/decode.py \ - --iter 36000 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 800 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --use-averaged-model True - -(3) fast beam search -./pruned_transducer_stateless5/decode.py \ - --iter 36000 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 800 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - --use-averaged-model True -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import Aishell4AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from local.text_normalize import text_normalize -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - Aishell4AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = 
torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - def text_normalize_for_cut(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = text.strip("\n").strip("\t") - c.supervisions[0].text = text_normalize(text) - return c - - # we need cut ids to display recognition results. 
- args.return_cuts = True - aishell4 = Aishell4AsrDataModule(args) - test_cuts = aishell4.test_cuts() - test_cuts = test_cuts.map(text_normalize_for_cut) - test_dl = aishell4.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dl = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 8a5e07bd5..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index 2fc10439b..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 246820833..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,278 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
-""" -Usage: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/aishell4/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py b/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index f31b5fd9b..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/local b/egs/aishell4/ASR/pruned_transducer_stateless5/local deleted file mode 120000 index c820590c5..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/local +++ /dev/null @@ -1 +0,0 @@ -../local \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/model.py b/egs/aishell4/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index be059ba7c..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py b/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 661206562..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100755 index e8b7f71b7..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -When use-averaged-model=True, usage: - -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir data/lang_char \ - --decoding-method greedy_search \ - --use-averaged-model True \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir data/lang_char \ - --use-averaged-model True \ - --decoding-method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search (not suggest) -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir data/lang_char \ - --use-averaged-model True \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir data/lang_char \ - --use-averaged-model True \ - --decoding-method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --decoding-method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.decoding_method}" - if params.decoding_method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - 
beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding-method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py b/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index be7b111c6..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py b/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py deleted file mode 100755 index d42c3b4f4..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -To run this file, do: - - cd icefall/egs/aishell4/ASR - python ./pruned_transducer_stateless5/test_model.py -""" - -from train import get_params, get_transducer_model - - -def test_model_1(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 24 - params.dim_feedforward = 1536 # 384 * 4 - params.encoder_dim = 384 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf -def test_model_M(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 18 - params.dim_feedforward = 1024 - params.encoder_dim = 256 - params.nhead = 4 - params.decoder_dim = 512 - params.joiner_dim = 512 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -def main(): - # test_model_1() - test_model_M() - - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index a354f761e..000000000 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1079 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import Aishell4AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from local.text_normalize import text_normalize -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 100, - "valid_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 400, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. 
- world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - # print(batch["supervisions"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell4 = Aishell4AsrDataModule(args) - # Combine all of the training data - train_cuts = aishell4.train_S_cuts() - train_cuts += aishell4.train_M_cuts() - train_cuts += aishell4.train_L_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 20.0 - - def text_normalize_for_cut(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = text.strip("\n").strip("\t") - c.supervisions[0].text = text_normalize(text) - return c - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(text_normalize_for_cut) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = aishell4.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = aishell4.valid_cuts() - valid_cuts = valid_cuts.map(text_normalize_for_cut) - valid_dl = aishell4.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - Aishell4AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/aishell4/ASR/shared b/egs/aishell4/ASR/shared deleted file mode 120000 index 3a3b28f96..000000000 --- a/egs/aishell4/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../egs/aishell/ASR/shared \ No newline at end of file diff --git a/egs/alimeeting/ASR/README.md b/egs/alimeeting/ASR/README.md deleted file mode 100644 index 257fe38d5..000000000 --- a/egs/alimeeting/ASR/README.md +++ /dev/null @@ -1,19 +0,0 @@ - -# Introduction - -This recipe includes some different ASR models trained with Alimeeting (far). - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/alimeeting/ASR/RESULTS.md b/egs/alimeeting/ASR/RESULTS.md deleted file mode 100644 index 745795a20..000000000 --- a/egs/alimeeting/ASR/RESULTS.md +++ /dev/null @@ -1,71 +0,0 @@ -## Results - -### Alimeeting Char training results (Pruned Transducer Stateless2) - -#### 2022-06-01 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/378. 
- -The WERs are -| | eval | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 31.77 | 34.66 | --epoch 29, --avg 18, --max-duration 100 | -| modified beam search (beam size 4) | 30.38 | 33.02 | --epoch 29, --avg 18, --max-duration 100 | -| fast beam search (set as default) | 31.39 | 34.25 | --epoch 29, --avg 18, --max-duration 1500| - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 220 \ - --save-every-n 1000 - -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/AoqgSvZKTZCJhJbOuG3W6g/#scalars - -The decoding command is: -``` -epoch=29 -avg=18 - -## greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 100 - -## modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir ./data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -A pre-trained model and decoding logs can be found at diff --git a/egs/alimeeting/ASR/local/__init__.py b/egs/alimeeting/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py deleted file mode 100755 index 09c873a34..000000000 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aishell dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. 
-# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_alimeeting( - num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/alimeeting") - output_dir = Path("data/fbank") - num_jobs = min(8, os.cpu_count()) - - dataset_parts = ( - "train", - "eval", - "test", - ) - - prefix = "alimeeting-far" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info("Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cur_num_jobs = num_jobs if ex is None else 80 - cur_num_jobs = min(cur_num_jobs, len(cut_set)) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=cur_num_jobs, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - logging.info("About splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - min_duration=None, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use the Whisper Fbank feature extractor. 
Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_alimeeting( - num_mel_bins=args.num_mel_bins, - perturb_speed=args.perturb_speed, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/alimeeting/ASR/local/compute_fbank_musan.py b/egs/alimeeting/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/alimeeting/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/local/display_manifest_statistics.py b/egs/alimeeting/ASR/local/display_manifest_statistics.py deleted file mode 100644 index 16cdecc91..000000000 --- a/egs/alimeeting/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. 
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/fbank/alimeeting_cuts_train.jsonl.gz", - "./data/fbank/alimeeting_cuts_eval.jsonl.gz", - "./data/fbank/alimeeting_cuts_test.jsonl.gz", - ] - - for path in paths: - print(f"Starting display the statistics for {path}") - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Starting display the statistics for ./data/fbank/alimeeting_cuts_train.jsonl.gz -Cuts count: 559092 -Total duration (hours): 424.6 -Speech duration (hours): 424.6 (100.0%) -*** -Duration statistics (seconds): -mean 2.7 -std 3.0 -min 0.0 -25% 0.7 -50% 1.7 -75% 3.6 -99% 13.6 -99.5% 14.7 -99.9% 16.2 -max 284.3 -Starting display the statistics for ./data/fbank/alimeeting_cuts_eval.jsonl.gz -Cuts count: 6457 -Total duration (hours): 4.9 -Speech duration (hours): 4.9 (100.0%) -*** -Duration statistics (seconds): -mean 2.7 -std 3.1 -min 0.1 -25% 0.6 -50% 1.6 -75% 3.5 -99% 13.6 -99.5% 14.1 -99.9% 14.7 -max 15.8 -Starting display the statistics for ./data/fbank/alimeeting_cuts_test.jsonl.gz -Cuts count: 16358 -Total duration (hours): 12.5 -Speech duration (hours): 12.5 (100.0%) -*** -Duration statistics (seconds): -mean 2.7 -std 2.9 -min 0.1 -25% 0.7 -50% 1.7 -75% 3.5 -99% 13.7 -99.5% 14.2 -99.9% 14.8 -max 15.7 -""" diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py deleted file mode 100755 index 6b440dfb3..000000000 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/text, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. 
- """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. - Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - chars = list(word.strip(" \t")) - if contain_oov(token_sym_table, chars): - continue - lexicon.append((word, chars)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. 
- """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([ \t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - chars = list(line) - for char in chars: - if char not in tokens: - tokens[char] = len(tokens) - return tokens - - -def main(): - lang_dir = Path("data/lang_char") - text_file = lang_dir / "text" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py deleted file mode 100755 index c8cf9b881..000000000 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. 
-""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. 
- Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") - return parser.parse_args() - - -def main(): - out_dir = Path(get_args().lang_dir) - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/local/prepare_words.py b/egs/alimeeting/ASR/local/prepare_words.py deleted file mode 100755 index 65aca2983..000000000 --- a/egs/alimeeting/ASR/local/prepare_words.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input words.txt without ids: - - words_no_ids.txt -and generates the new words.txt with related ids. - - words.txt -""" - - -import argparse -import logging - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare words.txt", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/words_no_ids.txt", - type=str, - help="the words file without ids for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/words.txt", - type=str, - help="the words file with ids for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - add_words = [" 0", "!SIL 1", " 2", " 3"] - new_lines.extend(add_words) - - logging.info("Starting reading the input file") - for i in tqdm(range(len(lines))): - x = lines[i] - idx = 4 + i - new_line = str(x.strip("\n")) + " " + str(idx) - new_lines.append(new_line) - - logging.info("Starting writing the words.txt") - f_out = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py deleted file mode 100755 index 74e025ad7..000000000 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py deleted file mode 100644 index 27b904fc8..000000000 --- a/egs/alimeeting/ASR/local/text2segments.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input "text", which refers to the transcript file for -WenetSpeech: - - text -and generates the output file text_word_segmentation which is implemented -with word segmenting: - - text_words_segmentation -""" - - -import argparse - -import jieba -import paddle -from tqdm import tqdm - -paddle.enable_static() -jieba.enable_paddle() - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Chinese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/text", - type=str, - help="the input text file for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with words segmenting for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - for i in tqdm(range(len(lines))): - x = lines[i].rstrip() - seg_list = jieba.cut(x, use_paddle=True) - new_line = " ".join(seg_list) - new_lines.append(new_line) - - f_new = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_new.write(line) - f_new.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py deleted file mode 100755 index 85047c367..000000000 --- a/egs/alimeeting/ASR/local/text2token.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import codecs -import re -import sys -from typing import List - -from pypinyin import lazy_pinyin, pinyin - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symobles, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "pinyin", "lazy_pinyin"], - help="""Transcript type. char/pinyin/lazy_pinyin""", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the chinese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "pinyin" and "lazy_pinyin". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
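To make the two transcript types concrete: lazy_pinyin() returns flat, tone-less syllables while pinyin() returns nested per-character lists with tone marks, which is why the "pinyin" branch below indexes txt[0]. A small sketch, assuming pypinyin is installed (the characters are arbitrary):

    from pypinyin import lazy_pinyin, pinyin

    chars = list("开会")
    print(lazy_pinyin(chars))  # ['kai', 'hui']
    print(pinyin(chars))       # [['kāi'], ['huì']]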
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "lazy_pinyin": - text = lazy_pinyin(chars_list) - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - else: # token_type = "pinyin" - text = pinyin(chars_list) - sub_ids = [ - token_table[txt[0]] if txt[0] in token_table else oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "pinyin": - a = pinyin(list(str(a))) - a = [one[0] for one in a] - - if args.trans_type == "lazy_pinyin": - a = lazy_pinyin(list(str(a))) - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = "".join(a_flat) - print(a_chars) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh deleted file mode 100755 index 55f9f019b..000000000 --- a/egs/alimeeting/ASR/prepare.sh +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=7 -perturb_speed=true - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/alimeeting -# This directory contains the following files downloaded from -# https://openslr.org/119/ -# -# - Train_Ali_far.tar.gz -# - Train_Ali_near.tar.gz -# - Test_Ali.tar.gz -# - Eval_Ali.tar.gz -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - if [ ! 
-f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then - lhotse download ali-meeting $dl_dir/alimeeting - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare alimeeting manifest" - # We assume that you have downloaded the alimeeting corpus - # to $dl_dir/alimeeting - if [ ! -f data/manifests/alimeeting/.manifests.done ]; then - mkdir -p data/manifests/alimeeting - lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting - touch data/manifests/alimeeting/.manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: compute fbank for alimeeting" - if [ ! -f data/fbank/.fbank.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} - touch data/fbank/.fbank.done - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then - log "Stage 20: compute whisper fbank for alimeeting" - if [ ! -f data/fbank/.fbank.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.fbank.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - lang_char_dir=data/lang_char - mkdir -p $lang_char_dir - - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - gunzip -c data/manifests/alimeeting/alimeeting_supervisions_train.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - - # Prepare words segments - python ./local/text2segments.py \ - --input $lang_char_dir/text \ - --output $lang_char_dir/text_words_segmentation - - cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \ - | sort -u | sed "/^$/d" \ - | uniq > $lang_char_dir/words_no_ids.txt - - # Prepare words.txt - if [ ! -f $lang_char_dir/words.txt ]; then - ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! 
-f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py - fi -fi diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/__init__.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index 410741215..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,412 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - -set_caching_enabled(False) -torch.set_num_threads(1) - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AlimeetingAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/dev/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - dev_iter_dataset = IterableDatasetWrapper( - dataset=validate, - sampler=valid_sampler, - ) - valid_dl = DataLoader( - dev_iter_dataset, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - test_iter_dataset = IterableDatasetWrapper( - dataset=test, - sampler=sampler, - ) - test_dl = DataLoader( - test_iter_dataset, - batch_size=None, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "alimeeting_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "alimeeting_cuts_eval.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "alimeeting_cuts_test.jsonl.gz" - ) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of 
file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index 6c170c392..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -When training with the far data, usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 29 \ - --avg 18 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 29 \ - --avg 18 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 29 \ - --avg 18 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AlimeetingAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--batch", - type=int, - default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. 
- """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 50 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AlimeetingAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - 
logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - elif params.batch is not None: - filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt" - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints([filenames], device=device)) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - average = average_checkpoints(filenames, device=device) - checkpoint = {"model": average} - torch.save( - checkpoint, - "pruned_transducer_stateless2/exp/pretrained_epoch_29_avg_18.pt", - ) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - - # we need cut ids to display recognition results. - args.return_cuts = True - alimeeting = AlimeetingAsrDataModule(args) - - dev = "eval" - test = "test" - - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev) - dev_cuts = alimeeting.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test}/shared-0.tar"): - os.makedirs(test) - test_cuts = alimeeting.test_cuts() - export_to_webdataset( - test_cuts, - output_path=f"{test}/shared-%d.tar", - shard_size=300, - ) - - dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) - ] - cuts_test_webdataset = CutSet.from_webdataset( - test_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - def remove_short_and_long_utt(c: Cut): - return 1.0 <= c.duration - - cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) - - dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) - test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git 
a/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py deleted file mode 100644 index 5dc73c52b..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens ./data/lang_char/tokens.txt \ - --epoch 29 \ - --avg 18 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/alimeeting/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py deleted file mode 100644 index a738bb3fb..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ /dev/null @@ -1,339 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
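# A minimal sketch (all names below are made up) of the torch.jit.ignore pattern
# used in export.py above: forward() takes a k2 ragged tensor that TorchScript
# cannot compile, so it is excluded from scripting and only the methods needed
# at inference time are compiled and saved.
import torch
import torch.nn as nn


class ToyTransducer(nn.Module):
    """Stand-in for the real transducer model; everything here is hypothetical."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(80, 512)

    @torch.jit.export
    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
        # The method that deployment code would call instead of forward().
        return torch.relu(self.encoder(x))

    def forward(self, x, y_ragged):
        # Pretend y_ragged is a k2.RaggedTensor, which TorchScript cannot handle.
        return self.run_encoder(x)


# Same trick as in export.py: leave forward() out of scripting, script the rest.
ToyTransducer.forward = torch.jit.ignore(ToyTransducer.forward)
scripted = torch.jit.script(ToyTransducer())
scripted.save("toy_cpu_jit.pt")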
-""" -Here, the far data is used for training, usage: - -(1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by -./pruned_transducer_stateless2/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search and modified_beam_search ", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. 
- """, - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - hyps = [] - msg = f"Using {params.decoding_method}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size 
= encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py deleted file mode 100644 index 30154291d..000000000 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,958 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
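# A small sketch, under assumed shapes, of how pretrained.py above batches
# utterances of different lengths before calling the encoder: each utterance's
# 80-dim fbank matrix is padded along the time axis with log(1e-10) (near
# silence in log-mel space), and the true lengths are passed separately so the
# encoder knows where the padding starts.
import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Stand-ins for the kaldifeat Fbank outputs of two utterances, each (T, 80).
features = [torch.randn(130, 80), torch.randn(95, 80)]
feature_lengths = torch.tensor([f.size(0) for f in features])

batch = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
assert batch.shape == (2, 130, 80)
# batch and feature_lengths are what model.encoder(x=..., x_lens=...) receives.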
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 220 \ - --save-every-n 1000 - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 220 \ - --save-every-n 1000 - --use-fp16 True - -""" - -import argparse -import logging -import os -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AlimeetingAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - -os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12359, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. 
- - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 10, - "log_interval": 1, - "reset_interval": 200, - "valid_interval": 400, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 200, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. 
- Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - alimeeting = AlimeetingAsrDataModule(args) - - train_cuts = alimeeting.train_cuts() - valid_cuts = alimeeting.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15.0 seconds - # - # Caution: There is a reason to select 10.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 15.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - valid_dl = alimeeting.valid_dataloaders(valid_cuts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = alimeeting.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - if not params.print_diagnostics and params.start_batch == 0: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - AlimeetingAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR/shared b/egs/alimeeting/ASR/shared deleted file mode 120000 index 3a3b28f96..000000000 --- a/egs/alimeeting/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../egs/aishell/ASR/shared \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/README.md b/egs/alimeeting/ASR_v2/README.md deleted file mode 100644 index f70327501..000000000 --- a/egs/alimeeting/ASR_v2/README.md +++ /dev/null @@ -1,38 +0,0 @@ - -# Introduction - -This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that -we train a single model on close-talk and far-field conditions. This recipe optionally -uses [GSS]-based enhancement for far-field array microphone. -We pool data in the following 4 ways and train a single model on the pooled data: - -(i) individual headset microphone (IHM) -(ii) IHM with simulated reverb -(iii) Single distant microphone (SDM) -(iv) GSS-enhanced array microphones - -This is different from `alimeeting/ASR` since that recipe trains a model only on the -far-field audio. Additionally, we use text normalization here similar to the original -M2MeT challenge, so the results should be more comparable to those from Table 4 of -the [paper](https://arxiv.org/abs/2110.07393). - -The following additional packages need to be installed to run this recipe: -* `pip install jieba` -* `pip install paddlepaddle` -* `pip install git+https://github.com/desh2608/gss.git` - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -## Performance Record - -### pruned_transducer_stateless7 - -The following are decoded using `modified_beam_search`: - -| Evaluation set | eval WER | test WER | -|--------------------------|------------|---------| -| IHM | 9.58 | 11.53 | -| SDM | 23.37 | 25.85 | -| MDM (GSS-enhanced) | 11.82 | 14.22 | - -See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details. diff --git a/egs/alimeeting/ASR_v2/RESULTS.md b/egs/alimeeting/ASR_v2/RESULTS.md deleted file mode 100644 index 15b24250d..000000000 --- a/egs/alimeeting/ASR_v2/RESULTS.md +++ /dev/null @@ -1,90 +0,0 @@ -## Results (CER) - -#### 2022-12-09 - -#### Zipformer (pruned_transducer_stateless7) - -Zipformer encoder + non-current decoder. The decoder -contains only an embedding layer, a Conv1d (with kernel size 2) and a linear -layer (to transform tensor dim). - -All the results below are using a single model that is trained by combining the following -data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise -augmentation are applied on top of the pooled data. 
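As a rough sketch of the pooling described above (this is not code from the recipe; the manifest names follow `local/compute_fbank_alimeeting.py` further down in this diff, speed perturbation is applied per condition before feature extraction, and MUSAN mixing is assumed to be handled separately in the data module):

```python
from lhotse import CutSet

train_cuts = (
    CutSet.from_file("data/manifests/cuts_train_ihm.jsonl.gz")
    + CutSet.from_file("data/manifests/cuts_train_ihm_rvb.jsonl.gz")
    + CutSet.from_file("data/manifests/cuts_train_sdm.jsonl.gz")
    + CutSet.from_file("data/manifests/cuts_train_gss.jsonl.gz")
)
```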
- -**WERs for IHM:** - -| | eval | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 10.13 | 12.21 | --epoch 15 --avg 8 --max-duration 500 | -| modified beam search | 9.58 | 11.53 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 9.92 | 12.07 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -**WERs for SDM:** - -| | eval | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 23.70 | 26.41 | --epoch 15 --avg 8 --max-duration 500 | -| modified beam search | 23.37 | 25.85 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 23.60 | 26.38 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -**WERs for GSS-enhanced MDM:** - -| | eval | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 12.24 | 14.99 | --epoch 15 --avg 8 --max-duration 500 | -| modified beam search | 11.82 | 14.22 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 12.30 | 14.98 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 15 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 300 \ - --max-cuts 100 \ - --prune-range 5 \ - --lr-factor 5 \ - --lm-scale 0.25 \ - --use-fp16 True -``` - -The decoding command is: -``` -# greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method greedy_search - -# modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -# fast beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -Pretrained model is available at - -The tensorboard training log can be found at - diff --git a/egs/alimeeting/ASR_v2/local/__init__.py b/egs/alimeeting/ASR_v2/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py deleted file mode 100755 index 833d11c72..000000000 --- a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the AliMeeting dataset. -For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced -audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced -parts (which are the 3 evaluation settings). -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import argparse -import logging -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_ami(perturb_speed: bool = False): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests_ihm = read_manifests_if_cached( - dataset_parts=["train", "eval", "test"], - output_dir=src_dir, - prefix="alimeeting-ihm", - suffix="jsonl.gz", - ) - manifests_sdm = read_manifests_if_cached( - dataset_parts=["train", "eval", "test"], - output_dir=src_dir, - prefix="alimeeting-sdm", - suffix="jsonl.gz", - ) - # For GSS we already have cuts so we read them directly. 
- manifests_gss = read_manifests_if_cached( - dataset_parts=["train", "eval", "test"], - output_dir=src_dir, - prefix="alimeeting-gss", - suffix="jsonl.gz", - ) - - def _extract_feats( - cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool - ) -> None: - if speed_perturb: - logging.info(f"Doing speed perturb") - cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) - _ = cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=storage_path, - manifest_path=manifest_path, - batch_duration=5000, - num_workers=8, - storage_type=LilcomChunkyWriter, - ) - - logging.info( - "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)" - ) - - logging.info("Processing train split IHM") - cuts_ihm = ( - CutSet.from_manifests(**manifests_ihm["train"]) - .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) - .modify_ids(lambda x: x + "-ihm") - ) - _extract_feats( - cuts_ihm, - output_dir / "feats_train_ihm", - src_dir / "cuts_train_ihm.jsonl.gz", - perturb_speed, - ) - - logging.info("Processing train split IHM + reverberated IHM") - cuts_ihm_rvb = cuts_ihm.reverb_rir() - _extract_feats( - cuts_ihm_rvb, - output_dir / "feats_train_ihm_rvb", - src_dir / "cuts_train_ihm_rvb.jsonl.gz", - perturb_speed, - ) - - logging.info("Processing train split SDM") - cuts_sdm = ( - CutSet.from_manifests(**manifests_sdm["train"]) - .trim_to_supervisions(keep_overlapping=False) - .modify_ids(lambda x: x + "-sdm") - ) - _extract_feats( - cuts_sdm, - output_dir / "feats_train_sdm", - src_dir / "cuts_train_sdm.jsonl.gz", - perturb_speed, - ) - - logging.info("Processing train split GSS") - cuts_gss = ( - CutSet.from_manifests(**manifests_gss["train"]) - .trim_to_supervisions(keep_overlapping=False) - .modify_ids(lambda x: x + "-gss") - ) - _extract_feats( - cuts_gss, - output_dir / "feats_train_gss", - src_dir / "cuts_train_gss.jsonl.gz", - perturb_speed, - ) - - logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") - for split in ["eval", "test"]: - logging.info(f"Processing {split} IHM") - cuts_ihm = ( - CutSet.from_manifests(**manifests_ihm[split]) - .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_ihm", - manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - logging.info(f"Processing {split} SDM") - cuts_sdm = ( - CutSet.from_manifests(**manifests_sdm[split]) - .trim_to_supervisions(keep_overlapping=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_sdm", - manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - logging.info(f"Processing {split} GSS") - cuts_gss = ( - CutSet.from_manifests(**manifests_gss[split]) - .trim_to_supervisions(keep_overlapping=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_gss", - manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - - compute_fbank_ami(perturb_speed=args.perturb_speed) diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py deleted file mode 100644 index f1512efa5..000000000 --- a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/local/bin/python -# -*- coding: utf-8 -*- -# Data preparation for AliMeeting GSS-enhanced dataset. - -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - -from lhotse import Recording, RecordingSet, SupervisionSet -from lhotse.qa import fix_manifests -from lhotse.recipes.utils import read_manifests_if_cached -from lhotse.utils import fastcopy -from tqdm import tqdm - -logging.basicConfig( - format="%(asctime)s %(levelname)-8s %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S", -) - - -def get_args(): - import argparse - - parser = argparse.ArgumentParser(description="AliMeeting enhanced dataset preparation.") - parser.add_argument( - "manifests_dir", - type=Path, - help="Path to directory containing AliMeeting manifests.", - ) - parser.add_argument( - "enhanced_dir", - type=Path, - help="Path to enhanced data directory.", - ) - parser.add_argument( - "--num-jobs", - "-j", - type=int, - default=1, - help="Number of parallel jobs to run.", - ) - parser.add_argument( - "--min-segment-duration", - "-d", - type=float, - default=0.0, - help="Minimum duration of a segment in seconds.", - ) - return parser.parse_args() - - -def find_recording_and_create_new_supervision(enhanced_dir, supervision): - """ - Given a supervision (corresponding to the original AliMeeting recording), this function finds the - enhanced recording corresponding to the supervision, and returns this recording and - a new supervision whose start and end times are adjusted to match the enhanced recording.
- """ - file_name = Path( - f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac" - ) - save_path = enhanced_dir / f"{supervision.recording_id}" / file_name - if save_path.exists(): - recording = Recording.from_file(save_path) - if recording.duration == 0: - logging.warning(f"Skipping {save_path} which has duration 0 seconds.") - return None - - # Old supervision is wrt to the original recording, we create new supervision - # wrt to the enhanced segment - new_supervision = fastcopy( - supervision, - recording_id=recording.id, - start=0, - duration=recording.duration, - ) - return recording, new_supervision - else: - logging.warning(f"{save_path} does not exist.") - return None - - -def main(args): - # Get arguments - manifests_dir = args.manifests_dir - enhanced_dir = args.enhanced_dir - - # Load manifests from cache if they exist (saves time) - manifests = read_manifests_if_cached( - dataset_parts=["train", "eval", "test"], - output_dir=manifests_dir, - prefix="alimeeting-sdm", - suffix="jsonl.gz", - ) - if not manifests: - raise ValueError( - "AliMeeting SDM manifests not found in {}".format(manifests_dir) - ) - - with ThreadPoolExecutor(args.num_jobs) as ex: - for part in ["train", "eval", "test"]: - logging.info(f"Processing {part}...") - supervisions_orig = manifests[part]["supervisions"].filter( - lambda s: s.duration >= args.min_segment_duration - ) - futures = [] - - for supervision in tqdm( - supervisions_orig, - desc="Distributing tasks", - ): - futures.append( - ex.submit( - find_recording_and_create_new_supervision, - enhanced_dir, - supervision, - ) - ) - - recordings = [] - supervisions = [] - for future in tqdm( - futures, - total=len(futures), - desc="Processing tasks", - ): - result = future.result() - if result is not None: - recording, new_supervision = result - recordings.append(recording) - supervisions.append(new_supervision) - - # Remove duplicates from the recordings - recordings_nodup = {} - for recording in recordings: - if recording.id not in recordings_nodup: - recordings_nodup[recording.id] = recording - else: - logging.warning("Recording {} is duplicated.".format(recording.id)) - recordings = RecordingSet.from_recordings(recordings_nodup.values()) - supervisions = SupervisionSet.from_segments(supervisions) - - recordings, supervisions = fix_manifests( - recordings=recordings, supervisions=supervisions - ) - - logging.info(f"Writing {part} enhanced manifests") - recordings.to_file( - manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz" - ) - supervisions.to_file( - manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz" - ) - - -if __name__ == "__main__": - args = get_args() - main(args) diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh deleted file mode 100755 index bd25bc9e5..000000000 --- a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash -# This script is used to run GSS-based enhancement on AMI data. -set -euo pipefail -nj=4 -stage=0 - -. shared/parse_options.sh || exit 1 - -if [ $# != 2 ]; then - echo "Wrong #arguments ($#, expected 2)" - echo "Usage: local/prepare_alimeeting_gss.sh [options] " - echo "e.g. 
local/prepare_alimeeting_gss.sh data/manifests exp/ami_gss" - echo "main options (for others, see top of script file)" - echo " --nj # number of parallel jobs" - echo " --stage # stage to start running from" - exit 1; -fi - -DATA_DIR=$1 -EXP_DIR=$2 - -mkdir -p $EXP_DIR - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 1 ]; then - log "Stage 1: Prepare cut sets" - for part in train eval test; do - lhotse cut simple \ - -r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \ - -s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \ - $EXP_DIR/cuts_${part}.jsonl.gz - done -fi - -if [ $stage -le 2 ]; then - log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" - for part in train eval test; do - lhotse cut trim-to-supervisions --discard-overlapping \ - $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz - done -fi - -if [ $stage -le 3 ]; then - log "Stage 3: Split manifests for multi-GPU processing (optional)" - for part in train eval test; do - gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj - done -fi - -if [ $stage -le 4 ]; then - log "Stage 4: Enhance train segments using GSS (requires GPU)" - # for train, we use smaller context and larger batches to speed-up processing - for JOB in $(seq $nj); do - gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ - --bss-iterations 10 \ - --context-duration 5.0 \ - --use-garbage-class \ - --channels 0,1,2,3,4,5,6,7 \ - --min-segment-length 0.05 \ - --max-segment-length 25.0 \ - --max-batch-duration 60.0 \ - --num-buckets 4 \ - --num-workers 4 - done -fi - -if [ $stage -le 5 ]; then - log "Stage 5: Enhance eval/test segments using GSS (using GPU)" - # for eval/test, we use larger context and smaller batches to get better quality - for part in eval test; do - for JOB in $(seq $nj); do - gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ - $EXP_DIR/enhanced \ - --bss-iterations 10 \ - --context-duration 15.0 \ - --use-garbage-class \ - --channels 0,1,2,3,4,5,6,7 \ - --min-segment-length 0.05 \ - --max-segment-length 16.0 \ - --max-batch-duration 45.0 \ - --num-buckets 4 \ - --num-workers 4 - done - done -fi - -if [ $stage -le 6 ]; then - log "Stage 6: Prepare manifests for GSS-enhanced data" - python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05 -fi diff --git a/egs/alimeeting/ASR_v2/local/prepare_char.py b/egs/alimeeting/ASR_v2/local/prepare_char.py deleted file mode 120000 index ee5dd34f1..000000000 --- a/egs/alimeeting/ASR_v2/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/prepare_words.py b/egs/alimeeting/ASR_v2/local/prepare_words.py deleted file mode 120000 index 970bfd60c..000000000 --- a/egs/alimeeting/ASR_v2/local/prepare_words.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/text2segments.py b/egs/alimeeting/ASR_v2/local/text2segments.py deleted file mode 120000 index bf4547794..000000000 --- a/egs/alimeeting/ASR_v2/local/text2segments.py +++ /dev/null @@ -1 +0,0 @@ 
-../../ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/text2token.py b/egs/alimeeting/ASR_v2/local/text2token.py deleted file mode 120000 index f6b8531b6..000000000 --- a/egs/alimeeting/ASR_v2/local/text2token.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/local/text2token.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh deleted file mode 100755 index 1881cd75c..000000000 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=-1 -stop_stage=100 -use_gss=true # Use GSS-based enhancement with MDM setting - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/alimeeting -# This directory contains the following files downloaded from -# https://openslr.org/119/ -# -# - Train_Ali_far.tar.gz -# - Train_Ali_near.tar.gz -# - Test_Ali.tar.gz -# - Eval_Ali.tar.gz -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then - lhotse download ali-meeting $dl_dir/alimeeting - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare alimeeting manifest" - # We assume that you have downloaded the alimeeting corpus - # to $dl_dir/alimeeting - for part in ihm sdm mdm; do - mkdir -p data/manifests/alimeeting - lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \ - $dl_dir/alimeeting data/manifests - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then - log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)" - # We assume that you have installed the GSS package: https://github.com/desh2608/gss - local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - python local/compute_fbank_musan.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for alimeeting" - mkdir -p data/fbank - python local/compute_fbank_alimeeting.py --perturb-speed True - log "Combine features from train splits" - lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ - gzip -c > data/manifests/cuts_train_all.jsonl.gz -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare char based lang" - lang_char_dir=data/lang_char - mkdir -p $lang_char_dir - - # Prepare text. 
- # Note: in Linux, you can install jq with the following command: - # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - - # Prepare words segments - python ./local/text2segments.py \ - --input $lang_char_dir/text \ - --output $lang_char_dir/text_words_segmentation - - cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \ - | sort -u | sed "/^$/d" \ - | uniq > $lang_char_dir/words_no_ids.txt - - # Prepare words.txt - if [ ! -f $lang_char_dir/words.txt ]; then - ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py - fi -fi diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py deleted file mode 100644 index 9da820315..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.cut import Cut -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader -from tqdm import tqdm - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AlimeetingAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), - ) - group.add_argument( - "--max-duration", - type=int, - default=100.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), - ) - group.add_argument( - "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch." - ) - group.add_argument( - "--num-buckets", - type=int, - default=50, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help=( - "When enabled (=default), the examples will be " - "shuffled for each epoch." - ), - ) - - group.add_argument( - "--num-workers", - type=int, - default=8, - help=( - "The number of training dataloader workers that " "collect the batches." - ), - ) - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to get Musan cuts") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - "Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - if self.args.on_the_fly_feats: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - ) - else: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - ) - - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=False, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), - return_cuts=True, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - def remove_short_cuts(self, cut: Cut) -> bool: - """ - See: https://github.com/k2-fsa/icefall/issues/500 - Basically, the zipformer model subsamples the input using the following formula: - num_out_frames = ((num_in_frames - 7)//2 + 1)//2 - For num_out_frames to be at least 1, num_in_frames must be at least 9. 
- """ - return cut.duration >= 0.09 - - @lru_cache() - def train_cuts(self, sp: Optional[Any] = None) -> CutSet: - logging.info("About to get AMI train cuts") - - def _remove_short_and_long_utt(c: Cut): - if c.duration < 0.1 or c.duration > 25.0: - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = c.supervisions[0].text - return T >= len(tokens) - - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_all.jsonl.gz" - ) - - return cuts_train.filter(_remove_short_and_long_utt) - - @lru_cache() - def eval_ihm_cuts(self) -> CutSet: - logging.info("About to get AliMeeting IHM eval cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def eval_sdm_cuts(self) -> CutSet: - logging.info("About to get AliMeeting SDM eval cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def eval_gss_cuts(self) -> CutSet: - if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists(): - logging.info("No GSS dev cuts found") - return None - logging.info("About to get AliMeeting GSS-enhanced eval cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_ihm_cuts(self) -> CutSet: - logging.info("About to get AliMeeting IHM test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_sdm_cuts(self) -> CutSet: - logging.info("About to get AliMeeting SDM test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_gss_cuts(self) -> CutSet: - if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): - logging.info("No GSS test cuts found") - return None - logging.info("About to get AliMeeting GSS-enhanced test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz") - return cs.filter(self.remove_short_cuts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py deleted file mode 120000 index 37516affc..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py deleted file mode 100755 index 2741e0eeb..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,692 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 15 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AlimeetingAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest_LG, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. 
- model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: 
hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AlimeetingAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest_LG", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - if "fast_beam_search" in params.decoding_method: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - alimeeting = AlimeetingAsrDataModule(args) - - eval_ihm_cuts = alimeeting.eval_ihm_cuts() - test_ihm_cuts = alimeeting.test_ihm_cuts() - eval_sdm_cuts = alimeeting.eval_sdm_cuts() - test_sdm_cuts = alimeeting.test_sdm_cuts() - eval_gss_cuts = alimeeting.eval_gss_cuts() - test_gss_cuts = alimeeting.test_gss_cuts() - - eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts) - test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts) - eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts) - test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts) - if eval_gss_cuts is not None: - eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts) - if test_gss_cuts is not None: - test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts) - - test_sets = { - "eval_ihm": (eval_ihm_dl, eval_ihm_cuts), - "test_ihm": (test_ihm_dl, test_ihm_cuts), - "eval_sdm": (eval_sdm_dl, eval_sdm_cuts), - "test_sdm": (test_sdm_dl, test_sdm_cuts), - } - if eval_gss_cuts is not None: - test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts) - if test_gss_cuts is not None: - test_sets["test_gss"] = (test_gss_dl, test_gss_cuts) - - for test_set in test_sets: - logging.info(f"Decoding {test_set}") - dl, cuts = test_sets[test_set] - results_dict = decode_dataset( - dl=dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1 +0,0 @@ 
-../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py deleted file mode 120000 index 0c2673d46..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py deleted file mode 100755 index 8bafaef44..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --tokens ./data/lang_char/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --tokens ./data/lang_char/tokens.txt \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. 
You can get it at - -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=15, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=8, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.script()") - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py deleted file mode 120000 index a44034e34..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py deleted file mode 120000 index 068f0f57f..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py 
b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py deleted file mode 120000 index 7ceac5d10..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py deleted file mode 100755 index 30879d8d2..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1174 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 15 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 150 \ - --use-fp16 True - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AlimeetingAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - 
default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=15, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. 
We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=5000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. 
It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: 
Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = ((feature_lens - 7) // 2).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. 
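As a quick check of the warm-up behaviour implemented in `compute_loss` above: the two loss scales move in opposite directions until `warm_step` is reached. The values below assume the recipe defaults `simple_loss_scale = 0.5` and `warm_step = 2000`; this is only a standalone sketch of the same arithmetic, not code from the recipe:

```python
# Standalone sketch of the warm-up schedule used in compute_loss above,
# with the defaults s = 0.5 (simple_loss_scale) and warm_step = 2000.
def loss_scales(batch_idx_train: int, s: float = 0.5, warm_step: int = 2000):
    simple_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_scale, pruned_scale


for step in (0, 500, 1000, 2000, 10000):
    print(step, loss_scales(step))
# 0     -> (1.0,   0.1)    rely mostly on the simple (un-pruned) loss
# 500   -> (0.875, 0.325)
# 1000  -> (0.75,  0.55)
# 2000  -> (0.5,   1.0)    fully warmed up; scales stay here afterwards
```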
- scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
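Relatedly, the `--average-period` help text earlier gives the update `model_avg = model * (p / n) + model_avg * ((n - p) / n)` applied by `update_averaged_model`. A standalone numerical check (a sketch with a single float standing in for a parameter tensor, not the icefall implementation) that this keeps `model_avg` equal to the plain mean of the models sampled every `average_period` batches:

```python
# Sketch: check that
#   model_avg = model * (p / n) + model_avg * ((n - p) / n)
# keeps model_avg equal to the plain mean of the models sampled every
# p = average_period batches. A single float stands in for a parameter tensor.
average_period = 200  # p

model_avg = 0.0
sampled = []

for k in range(1, 6):
    batch_idx_train = k * average_period      # n at this update
    model = float(k)                          # current value of the "parameter"
    sampled.append(model)
    model_avg = model * (average_period / batch_idx_train) + model_avg * (
        (batch_idx_train - average_period) / batch_idx_train
    )
    assert abs(model_avg - sum(sampled) / len(sampled)) < 1e-12

print(model_avg)  # 3.0, i.e. the mean of 1..5
```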
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - if params.inf_check: - register_inf_check_hooks(model) - - alimeeting = AlimeetingAsrDataModule(args) - - train_cuts = alimeeting.train_cuts() - train_dl = alimeeting.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = alimeeting.eval_ihm_cuts() - valid_dl = alimeeting.valid_dataloaders(valid_cuts) - - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # graph_compiler=graph_compiler, - # params=params, - # ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + 
epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AlimeetingAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/shared b/egs/alimeeting/ASR_v2/shared deleted file mode 120000 index 3a3b28f96..000000000 --- a/egs/alimeeting/ASR_v2/shared +++ /dev/null @@ -1 +0,0 @@ -../../../egs/aishell/ASR/shared \ No newline at end of file diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md deleted file mode 100644 index 1c9714bd4..000000000 --- a/egs/ami/ASR/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# AMI - -This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's -headset and lapel microphones, and also 2 array microphones containing 8 channels each. -We pool data in the following 4 ways and train a single model on the pooled data: - -(i) individual headset microphone (IHM) -(ii) IHM with simulated reverb -(iii) Single distant microphone (SDM) -(iv) GSS-enhanced array microphones - -Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled -data. Here are the statistics of the combined training data: - -```python ->>> cuts_train.describe() -Cuts count: 1222053 -Total duration (hh:mm:ss): 905:00:28 -Speech duration (hh:mm:ss): 905:00:28 (99.9%) -Duration statistics (seconds): -mean 2.7 -std 2.8 -min 0.0 -25% 0.6 -50% 1.6 -75% 3.8 -99% 12.3 -99.5% 13.9 -99.9% 18.4 -max 36.8 -``` - -**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement -of far-field array microphones, but this is optional (see `prepare.sh` for details). - -## Performance Record - -### pruned_transducer_stateless7 - -The following are decoded using `modified_beam_search`: - -| Evaluation set | dev WER | test WER | -|--------------------------|------------|---------| -| IHM | 18.92 | 17.40 | -| SDM | 31.25 | 32.21 | -| MDM (GSS-enhanced) | 21.67 | 22.43 | - -See [RESULTS](/egs/ami/ASR/RESULTS.md) for details. diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md deleted file mode 100644 index 163986021..000000000 --- a/egs/ami/ASR/RESULTS.md +++ /dev/null @@ -1,92 +0,0 @@ -## Results - -### AMI training results (Pruned Transducer) - -#### 2022-11-20 - -#### Zipformer (pruned_transducer_stateless7) - -Zipformer encoder + non-current decoder. The decoder -contains only an embedding layer, a Conv1d (with kernel size 2) and a linear -layer (to transform tensor dim). - -All the results below are using a single model that is trained by combining the following -data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. 
Speed perturbation and MUSAN noise -augmentation are applied on top of the pooled data. - -**WERs for IHM:** - -| | dev | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 19.25 | 17.83 | --epoch 14 --avg 8 --max-duration 500 | -| modified beam search | 18.92 | 17.40 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 19.44 | 18.04 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -**WERs for SDM:** - -| | dev | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 31.32 | 32.38 | --epoch 14 --avg 8 --max-duration 500 | -| modified beam search | 31.25 | 32.21 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 31.11 | 32.10 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -**WERs for GSS-enhanced MDM:** - -| | dev | test | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 22.05 | 22.93 | --epoch 14 --avg 8 --max-duration 500 | -| modified beam search | 21.67 | 22.43 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | -| fast beam search | 22.21 | 22.83 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 15 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 150 \ - --max-cuts 150 \ - --prune-range 5 \ - --lr-factor 5 \ - --lm-scale 0.25 \ - --use-fp16 True -``` - -The decoding command is: -``` -# greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 14 \ - --avg 8 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method greedy_search - -# modified beam search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -# fast beam search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -Pretrained model is available at - -The tensorboard training log can be found at - diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py deleted file mode 100755 index 4892b40e3..000000000 --- a/egs/ami/ASR/local/compute_fbank_ami.py +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the AMI dataset. -For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced -audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced -parts (which are the 3 evaluation settings). -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import logging -import math -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_ami(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests_ihm = read_manifests_if_cached( - dataset_parts=["train", "dev", "test"], - output_dir=src_dir, - prefix="ami-ihm", - suffix="jsonl.gz", - ) - manifests_sdm = read_manifests_if_cached( - dataset_parts=["train", "dev", "test"], - output_dir=src_dir, - prefix="ami-sdm", - suffix="jsonl.gz", - ) - # For GSS we already have cuts so we read them directly. 
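Once this script has written its cut manifests, they can be sanity-checked quickly. A sketch assuming lhotse's standard `load_manifest_lazy` API and the output paths used by `compute_fbank_ami.py`; adjust the path for the split you want to inspect:

```python
# Hypothetical sanity check of the cut manifests written by this script;
# the path follows the naming used in compute_fbank_ami.py (cuts_train_ihm.jsonl.gz).
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/manifests/cuts_train_ihm.jsonl.gz")

for i, cut in enumerate(cuts):
    feats = cut.load_features()  # numpy array, expected shape (num_frames, 80)
    assert feats.shape[1] == 80, feats.shape
    print(cut.id, round(cut.duration, 2), feats.shape)
    if i == 4:  # only inspect the first few cuts
        break
```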
- manifests_gss = read_manifests_if_cached( - dataset_parts=["train", "dev", "test"], - output_dir=src_dir, - prefix="ami-gss", - suffix="jsonl.gz", - ) - - def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: - cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) - _ = cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=storage_path, - manifest_path=manifest_path, - batch_duration=5000, - num_workers=8, - storage_type=LilcomChunkyWriter, - ) - - logging.info( - "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)" - ) - - logging.info("Processing train split IHM") - cuts_ihm = ( - CutSet.from_manifests(**manifests_ihm["train"]) - .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) - .modify_ids(lambda x: x + "-ihm") - ) - _extract_feats( - cuts_ihm, - output_dir / "feats_train_ihm", - src_dir / "cuts_train_ihm.jsonl.gz", - ) - - logging.info("Processing train split IHM + reverberated IHM") - cuts_ihm_rvb = cuts_ihm.reverb_rir() - _extract_feats( - cuts_ihm_rvb, - output_dir / "feats_train_ihm_rvb", - src_dir / "cuts_train_ihm_rvb.jsonl.gz", - ) - - logging.info("Processing train split SDM") - cuts_sdm = ( - CutSet.from_manifests(**manifests_sdm["train"]) - .trim_to_supervisions(keep_overlapping=False) - .modify_ids(lambda x: x + "-sdm") - ) - _extract_feats( - cuts_sdm, - output_dir / "feats_train_sdm", - src_dir / "cuts_train_sdm.jsonl.gz", - ) - - logging.info("Processing train split GSS") - cuts_gss = ( - CutSet.from_manifests(**manifests_gss["train"]) - .trim_to_supervisions(keep_overlapping=False) - .modify_ids(lambda x: x + "-gss") - ) - _extract_feats( - cuts_gss, - output_dir / "feats_train_gss", - src_dir / "cuts_train_gss.jsonl.gz", - ) - - logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") - for split in ["dev", "test"]: - logging.info(f"Processing {split} IHM") - cuts_ihm = ( - CutSet.from_manifests(**manifests_ihm[split]) - .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_ihm", - manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz", - batch_duration=5000, - num_workers=8, - storage_type=LilcomChunkyWriter, - ) - ) - logging.info(f"Processing {split} SDM") - cuts_sdm = ( - CutSet.from_manifests(**manifests_sdm[split]) - .trim_to_supervisions(keep_overlapping=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_sdm", - manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - logging.info(f"Processing {split} GSS") - cuts_gss = ( - CutSet.from_manifests(**manifests_gss[split]) - .trim_to_supervisions(keep_overlapping=False) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{split}_gss", - manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_ami() diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py deleted file mode 100755 index 1fcf951f9..000000000 --- a/egs/ami/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, LilcomChunkyWriter, combine -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = src_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - # create chunks of Musan with duration 5 - 10 seconds - _ = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "musan_feats", - manifest_path=musan_cuts_path, - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py deleted file mode 100644 index bed220eb3..000000000 --- a/egs/ami/ASR/local/prepare_ami_enhanced.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/local/bin/python -# -*- coding: utf-8 -*- -# Data preparation for AMI GSS-enhanced dataset. 
- -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - -from lhotse import Recording, RecordingSet, SupervisionSet -from lhotse.qa import fix_manifests -from lhotse.recipes.utils import read_manifests_if_cached -from lhotse.utils import fastcopy -from tqdm import tqdm - -logging.basicConfig( - format="%(asctime)s %(levelname)-8s %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S", -) - - -def get_args(): - import argparse - - parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.") - parser.add_argument( - "manifests_dir", - type=Path, - help="Path to directory containing AMI manifests.", - ) - parser.add_argument( - "enhanced_dir", - type=Path, - help="Path to enhanced data directory.", - ) - parser.add_argument( - "--num-jobs", - "-j", - type=int, - default=1, - help="Number of parallel jobs to run.", - ) - parser.add_argument( - "--min-segment-duration", - "-d", - type=float, - default=0.0, - help="Minimum duration of a segment in seconds.", - ) - return parser.parse_args() - - -def find_recording_and_create_new_supervision(enhanced_dir, supervision): - """ - Given a supervision (corresponding to original AMI recording), this function finds the - enhanced recording correspoding to the supervision, and returns this recording and - a new supervision whose start and end times are adjusted to match the enhanced recording. - """ - file_name = Path( - f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac" - ) - save_path = enhanced_dir / f"{supervision.recording_id}" / file_name - if save_path.exists(): - recording = Recording.from_file(save_path) - if recording.duration == 0: - logging.warning(f"Skipping {save_path} which has duration 0 seconds.") - return None - - # Old supervision is wrt to the original recording, we create new supervision - # wrt to the enhanced segment - new_supervision = fastcopy( - supervision, - recording_id=recording.id, - start=0, - duration=recording.duration, - ) - return recording, new_supervision - else: - logging.warning(f"{save_path} does not exist.") - return None - - -def main(args): - # Get arguments - manifests_dir = args.manifests_dir - enhanced_dir = args.enhanced_dir - - # Load manifests from cache if they exist (saves time) - manifests = read_manifests_if_cached( - dataset_parts=["train", "dev", "test"], - output_dir=manifests_dir, - prefix="ami-sdm", - suffix="jsonl.gz", - ) - if not manifests: - raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir)) - - with ThreadPoolExecutor(args.num_jobs) as ex: - for part in ["train", "dev", "test"]: - logging.info(f"Processing {part}...") - supervisions_orig = manifests[part]["supervisions"].filter( - lambda s: s.duration >= args.min_segment_duration - ) - # Remove TS3009d supervisions since they are not present in the enhanced data - supervisions_orig = supervisions_orig.filter( - lambda s: s.recording_id != "TS3009d" - ) - futures = [] - - for supervision in tqdm( - supervisions_orig, - desc="Distributing tasks", - ): - futures.append( - ex.submit( - find_recording_and_create_new_supervision, - enhanced_dir, - supervision, - ) - ) - - recordings = [] - supervisions = [] - for future in tqdm( - futures, - total=len(futures), - desc="Processing tasks", - ): - result = future.result() - if result is not None: - recording, new_supervision = result - recordings.append(recording) - supervisions.append(new_supervision) - - # Remove duplicates from the 
recordings - recordings_nodup = {} - for recording in recordings: - if recording.id not in recordings_nodup: - recordings_nodup[recording.id] = recording - else: - logging.warning("Recording {} is duplicated.".format(recording.id)) - recordings = RecordingSet.from_recordings(recordings_nodup.values()) - supervisions = SupervisionSet.from_segments(supervisions) - - recordings, supervisions = fix_manifests( - recordings=recordings, supervisions=supervisions - ) - - logging.info(f"Writing {part} enhanced manifests") - recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz") - supervisions.to_file( - manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz" - ) - - -if __name__ == "__main__": - args = get_args() - main(args) diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh deleted file mode 100755 index 414c22b12..000000000 --- a/egs/ami/ASR/local/prepare_ami_gss.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash -# This script is used to run GSS-based enhancement on AMI data. -set -euo pipefail -nj=4 -stage=0 - -. shared/parse_options.sh || exit 1 - -if [ $# != 2 ]; then - echo "Wrong #arguments ($#, expected 2)" - echo "Usage: local/prepare_ami_gss.sh [options] " - echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss" - echo "main options (for others, see top of script file)" - echo " --nj # number of parallel jobs" - echo " --stage # stage to start running from" - exit 1; -fi - -DATA_DIR=$1 -EXP_DIR=$2 - -mkdir -p $EXP_DIR - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 1 ]; then - log "Stage 1: Prepare cut sets" - for part in train dev test; do - lhotse cut simple \ - -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \ - -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \ - $EXP_DIR/cuts_${part}.jsonl.gz - done -fi - -if [ $stage -le 2 ]; then - log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" - for part in train dev test; do - lhotse cut trim-to-supervisions --discard-overlapping \ - $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz - done -fi - -if [ $stage -le 3 ]; then - log "Stage 3: Split manifests for multi-GPU processing (optional)" - for part in train; do - gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj - done -fi - -if [ $stage -le 4 ]; then - log "Stage 4: Enhance train segments using GSS (requires GPU)" - # for train, we use smaller context and larger batches to speed-up processing - for JOB in $(seq $nj); do - gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ - --bss-iterations 10 \ - --context-duration 5.0 \ - --use-garbage-class \ - --channels 0,1,2,3,4,5,6,7 \ - --min-segment-length 0.05 \ - --max-segment-length 35.0 \ - --max-batch-duration 60.0 \ - --num-buckets 3 \ - --num-workers 2 - done -fi - -if [ $stage -le 5 ]; then - log "Stage 5: Enhance dev/test segments using GSS (using GPU)" - # for dev/test, we use larger context and smaller batches to get better quality - for part in dev test; do - for JOB in $(seq $nj); do - gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ - $EXP_DIR/enhanced \ - --bss-iterations 10 \ - --context-duration 15.0 \ - --use-garbage-class \ - --channels 
0,1,2,3,4,5,6,7 \ - --min-segment-length 0.05 \ - --max-segment-length 30.0 \ - --max-batch-duration 45.0 \ - --num-buckets 3 \ - --num-workers 2 - done - done -fi - -if [ $stage -le 6 ]; then - log "Stage 6: Prepare manifests for GSS-enhanced data" - python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05 -fi diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/ami/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/ami/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh deleted file mode 100755 index fb21a8ec6..000000000 --- a/egs/ami/ASR/prepare.sh +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=-1 -stop_stage=100 -use_gss=true # Use GSS-based enhancement with MDM setting - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/amicorpus -# You can find audio and transcripts in this path. -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -# -# - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19} -# These contain the Fisher English audio and transcripts. We will -# only use the transcripts as extra LM training data (similar to Kaldi). -# -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data -vocab_size=500 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/amicorpus, - # you can create a symlink - # - # ln -sfv /path/to/amicorpus $dl_dir/amicorpus - # - if [ ! -d $dl_dir/amicorpus ]; then - lhotse download ami --mic ihm $dl_dir/amicorpus - lhotse download ami --mic mdm $dl_dir/amicorpus - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare AMI manifests" - # We assume that you have downloaded the AMI corpus - # to $dl_dir/amicorpus. We perform text normalization for the transcripts. 
- mkdir -p data/manifests - for mic in ihm sdm mdm; do - lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \ - --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/ - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then - log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)" - # We assume that you have installed the GSS package: https://github.com/desh2608/gss - local/prepare_ami_gss.sh data/manifests exp/ami_gss -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank features for AMI" - mkdir -p data/fbank - python local/compute_fbank_ami.py - log "Combine features from train splits" - lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ - gzip -c > data/manifests/cuts_train_all.jsonl.gz -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank features for musan" - mkdir -p data/fbank - python local/compute_fbank_musan.py -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Dump transcripts for BPE model training." - mkdir -p data/lm - cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare BPE based lang" - - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - # Add special words to words.txt - echo " 0" > $lang_dir/words.txt - echo "!SIL 1" >> $lang_dir/words.txt - echo " 2" >> $lang_dir/words.txt - - # Add regular words to words.txt - cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt - - # Add remaining special word symbols expected by LM scripts. - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo "#0 ${num_words}" >> $lang_dir/words.txt - - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript data/lm/transcript_words.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi -fi diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py deleted file mode 100644 index 554facfc1..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.cut import Cut -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader -from tqdm import tqdm - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AmiAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), - ) - group.add_argument( - "--max-duration", - type=int, - default=100.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), - ) - group.add_argument( - "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch." 
- ) - group.add_argument( - "--num-buckets", - type=int, - default=50, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help=( - "When enabled (=default), the examples will be " - "shuffled for each epoch." - ), - ) - - group.add_argument( - "--num-workers", - type=int, - default=8, - help=( - "The number of training dataloader workers that " "collect the batches." - ), - ) - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), - ) - group.add_argument( - "--ihm-only", - type=str2bool, - default=False, - help="When enabled, only use IHM data for training.", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - "Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
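The options registered by add_arguments above are consumed by constructing the module from parsed args; decode.py below wires it up the same way for its test dataloaders. A minimal illustrative sketch (option values and the validation set choice are examples only; the prepared cut manifests are assumed to exist under --manifest-dir):

import argparse

import sentencepiece as spm
from asr_datamodule import AmiAsrDataModule

parser = argparse.ArgumentParser()
AmiAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/manifests", "--max-duration", "150"])

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

ami = AmiAsrDataModule(args)
train_dl = ami.train_dataloaders(ami.train_cuts(sp=sp))  # train cuts filtered by duration/token count
valid_dl = ami.valid_dataloaders(ami.dev_ihm_cuts())     # e.g. IHM dev cuts for validation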
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - if self.args.on_the_fly_feats: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - ) - else: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - ) - - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=False, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=True, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - def remove_short_cuts(self, cut: Cut) -> bool: - """ - See: https://github.com/k2-fsa/icefall/issues/500 - Basically, the zipformer model subsamples the input using the following formula: - num_out_frames = (num_in_frames - 7)//2 - For num_out_frames to 
be at least 1, num_in_frames must be at least 9. - """ - return cut.duration >= 0.09 - - @lru_cache() - def train_cuts(self, sp: Optional[Any] = None) -> CutSet: - logging.info("About to get AMI train cuts") - - def _remove_short_and_long_utt(c: Cut): - if c.duration < 0.2 or c.duration > 25.0: - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - return T >= len(tokens) - - if self.args.ihm_only: - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_ihm.jsonl.gz" - ) - else: - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_all.jsonl.gz" - ) - - return cuts_train.filter(_remove_short_and_long_utt) - - @lru_cache() - def dev_ihm_cuts(self) -> CutSet: - logging.info("About to get AMI IHM dev cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def dev_sdm_cuts(self) -> CutSet: - logging.info("About to get AMI SDM dev cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def dev_gss_cuts(self) -> CutSet: - if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists(): - logging.info("No GSS dev cuts found") - return None - logging.info("About to get AMI GSS-enhanced dev cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_ihm_cuts(self) -> CutSet: - logging.info("About to get AMI IHM test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_sdm_cuts(self) -> CutSet: - logging.info("About to get AMI SDM test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") - return cs.filter(self.remove_short_cuts) - - @lru_cache() - def test_gss_cuts(self) -> CutSet: - if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): - logging.info("No GSS test cuts found") - return None - logging.info("About to get AMI GSS-enhanced test cuts") - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz") - return cs.filter(self.remove_short_cuts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py deleted file mode 120000 index 37516affc..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py deleted file mode 100755 index 9999894d1..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,739 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 500 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless7/decode.py \ - --iter 105000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AmiAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest_LG, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - word_table: Optional[k2.SymbolTable] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. 
- Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - word_table: - The word symbol table. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": 
hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, - word_table: Optional[k2.SymbolTable] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - test_set_cers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" - with open(wers_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - # we also compute CER for AMI dataset. 
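A quick worked example of the word-to-character conversion performed just below (cut id and transcripts invented):

res = ("cut-0", ["HELLO", "WORLD"], ["HELLO", "WORD"])
ref_chars = list("".join(res[1]))  # ['H', 'E', 'L', 'L', 'O', 'W', 'O', 'R', 'L', 'D']
hyp_chars = list("".join(res[2]))  # ['H', 'E', 'L', 'L', 'O', 'W', 'O', 'R', 'D']
# write_error_stats then aligns the character sequences, yielding a CER instead of a WER.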
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" - with open(cers_filename, "w") as f: - cer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True - ) - test_set_cers[key] = cer - - logging.info("Wrote detailed error stats to {}".format(wers_filename)) - - test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} - test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER\tCER", file=f) - for key in test_set_wers: - print( - "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), - file=f, - ) - - s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key in test_set_wers: - s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AmiAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest_LG", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(f"{params.lang_dir}/bpe.model") - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - ami = AmiAsrDataModule(args) - - dev_ihm_cuts = ami.dev_ihm_cuts() - test_ihm_cuts = ami.test_ihm_cuts() - dev_sdm_cuts = ami.dev_sdm_cuts() - test_sdm_cuts = ami.test_sdm_cuts() - dev_gss_cuts = ami.dev_gss_cuts() - test_gss_cuts = ami.test_gss_cuts() - - dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts) - test_ihm_dl = ami.test_dataloaders(test_ihm_cuts) - dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts) - test_sdm_dl = ami.test_dataloaders(test_sdm_cuts) - if dev_gss_cuts is not None: - dev_gss_dl = ami.test_dataloaders(dev_gss_cuts) - if test_gss_cuts is not None: - test_gss_dl = ami.test_dataloaders(test_gss_cuts) - - test_sets = { - "dev_ihm": (dev_ihm_dl, dev_ihm_cuts), - "test_ihm": (test_ihm_dl, test_ihm_cuts), - "dev_sdm": (dev_sdm_dl, dev_sdm_cuts), - "test_sdm": (test_sdm_dl, test_sdm_cuts), - } - if dev_gss_cuts is not None: - test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts) - if test_gss_cuts is not None: - test_sets["test_gss"] = (test_gss_dl, test_gss_cuts) - - for test_set in test_sets: - 
logging.info(f"Decoding {test_set}") - dl, cuts = test_sets[test_set] - results_dict = decode_dataset( - dl=dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py deleted file mode 120000 index 0c2673d46..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py deleted file mode 120000 index 2713792e6..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py deleted file mode 100755 index d62cdadb7..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1181 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi 
Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 15 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 150 \ - --use-fp16 True - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AmiAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, 
comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=11, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=5000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. 
- - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
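The body below blends the simple and pruned transducer losses with scales that ramp over the first warm_step batches. A small self-contained sketch of that schedule, using the defaults warm_step=2000 (from get_params) and --simple-loss-scale 0.5:

def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # Returns (simple_loss_scale, pruned_loss_scale) as computed in compute_loss.
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

print(loss_scales(0))     # (1.0, 0.1): early on, training leans on the simple loss
print(loss_scales(500))   # roughly (0.875, 0.325): a quarter of the way through warm-up
print(loss_scales(2000))  # (0.5, 1.0): fully warmed up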
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"] - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = supervisions["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = ((feature_lens - 7) // 2).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. 
If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
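The comment above notes that the grad scaler's growth interval is configurable but cannot be made to depend on the current scale, which is why the lines just below grow the scale manually. For reference, these are the constructor knobs being alluded to; the values here are only illustrative (this recipe constructs its scaler with `init_scale=1.0` later in `run()`):

```python
import torch
from torch.cuda.amp import GradScaler

# Constructor-level control over scale growth: after `growth_interval`
# consecutive steps without inf/NaN gradients the scale is multiplied by
# `growth_factor`; on overflow it is multiplied by `backoff_factor`.
# This behaviour is fixed per scaler and cannot depend on the current scale,
# hence the manual doubling in train_one_epoch below.
scaler = GradScaler(
    enabled=torch.cuda.is_available(),
    init_scale=1.0,        # start small, as this recipe does
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,  # PyTorch's default growth interval
)
print("initial scale:", scaler.get_scale())
```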
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - if params.inf_check: - register_inf_check_hooks(model) - - ami = AmiAsrDataModule(args) - - # Here is the duration statistics of the training set. 
- # Cuts count: 1230033 - # Total duration (hh:mm:ss): 904:25:34 - # Speech duration (hh:mm:ss): 904:25:34 (100.0%) - # Duration statistics (seconds): - # mean 2.6 - # std 2.8 - # min 0.0 - # 25% 0.6 - # 50% 1.6 - # 75% 3.8 - # 99% 12.3 - # 99.5% 13.9 - # 99.9% 18.3 - # max 36.8 - - train_cuts = ami.train_cuts(sp=sp) - train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) - - valid_cuts = ami.dev_ihm_cuts() - valid_dl = ami.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. 
We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AmiAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/ami/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/ami/SURT/README.md b/egs/ami/SURT/README.md deleted file mode 100644 index 74a8ba014..000000000 --- a/egs/ami/SURT/README.md +++ /dev/null @@ -1,156 +0,0 @@ -# Introduction - -This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming -Unmixing and Recognition Transducer (SURT) model for the task. - -Please refer to the `egs/libricss/SURT` recipe README for details about the task and the -model. - -## Description of the recipe - -### Pre-requisites - -The recipes in this directory need the following packages to be installed: - -- [meeteval](https://github.com/fgnt/meeteval) -- [einops](https://github.com/arogozhnikov/einops) - -Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe. -Please download this checkpoint (see below) or train the LibriCSS recipe first. - -### Training - -To train the model, run the following from within `egs/ami/SURT`: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -python dprnn_zipformer/train.py \ - --use-fp16 True \ - --exp-dir dprnn_zipformer/exp/surt_base \ - --world-size 4 \ - --max-duration 500 \ - --max-duration-valid 250 \ - --max-cuts 200 \ - --num-buckets 50 \ - --num-epochs 30 \ - --enable-spec-aug True \ - --enable-musan False \ - --ctc-loss-scale 0.2 \ - --heat-loss-scale 0.2 \ - --base-lr 0.004 \ - --model-init-ckpt exp/libricss_base.pt \ - --chunk-width-randomization True \ - --num-mask-encoder-layers 4 \ - --num-encoder-layers 2,2,2,2,2 -``` - -The above is for SURT-base (~26M). For SURT-large (~38M), use: - -```bash - --model-init-ckpt exp/libricss_large.pt \ - --num-mask-encoder-layers 6 \ - --num-encoder-layers 2,4,3,2,4 \ - --model-init-ckpt exp/zipformer_large.pt \ -``` - -**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM. - -### Adaptation - -The training step above only trains on simulated mixtures. For best results, we also -adapt the final model on the AMI+ICSI train set. 
For this, run the following from within -`egs/ami/SURT`: - -```bash -export CUDA_VISIBLE_DEVICES="0" - -python dprnn_zipformer/train_adapt.py \ - --use-fp16 True \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --world-size 4 \ - --max-duration 500 \ - --max-duration-valid 250 \ - --max-cuts 200 \ - --num-buckets 50 \ - --num-epochs 8 \ - --lr-epochs 2 \ - --enable-spec-aug True \ - --enable-musan False \ - --ctc-loss-scale 0.2 \ - --base-lr 0.0004 \ - --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \ - --chunk-width-randomization True \ - --num-mask-encoder-layers 4 \ - --num-encoder-layers 2,2,2,2,2 -``` - -For SURT-large, use the following config: - -```bash - --num-mask-encoder-layers 6 \ - --num-encoder-layers 2,4,3,2,4 \ - --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \ - --num-epochs 15 \ - --lr-epochs 4 \ -``` - - -### Decoding - -To decode the model, run the following from within `egs/ami/SURT`: - -#### Greedy search - -```bash -export CUDA_VISIBLE_DEVICES="0" - -python dprnn_zipformer/decode.py \ - --epoch 20 --avg 1 --use-averaged-model False \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --max-duration 250 \ - --decoding-method greedy_search -``` - -#### Beam search - -```bash -python dprnn_zipformer/decode.py \ - --epoch 20 --avg 1 --use-averaged-model False \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --max-duration 250 \ - --decoding-method modified_beam_search \ - --beam-size 4 -``` - -## Results (using beam search) - -**AMI** - -| Model | IHM-Mix | SDM | MDM | -|------------|:-------:|:----:|:----:| -| SURT-base | 39.8 | 65.4 | 46.6 | -| + adapt | 37.4 | 46.9 | 43.7 | -| SURT-large | 36.8 | 62.5 | 44.4 | -| + adapt | **35.1** | **44.6** | **41.4** | - -**ICSI** - -| Model | IHM-Mix | SDM | -|------------|:-------:|:----:| -| SURT-base | 28.3 | 60.0 | -| + adapt | 26.3 | 33.9 | -| SURT-large | 27.8 | 59.7 | -| + adapt | **24.4** | **32.3** | - -## Pre-trained models and logs - -* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large) - -* Pre-trained models: - -* Training logs: - - surt_base: - - surt_base_adapt: - - surt_large: - - surt_large_adapt: diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py deleted file mode 100644 index ea8b62242..000000000 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ /dev/null @@ -1,401 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
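The `asr_datamodule.py` whose license header ends above defines the `AmiAsrDataModule` that the SURT train and decode scripts import. As a rough usage sketch (run from within `egs/ami/SURT/dprnn_zipformer`, and assuming the manifests produced by `prepare.sh` exist under `data/manifests`), the module is driven like this:

```python
import argparse

from asr_datamodule import AmiAsrDataModule  # the class defined below

parser = argparse.ArgumentParser()
AmiAsrDataModule.add_arguments(parser)  # registers --max-duration, --num-buckets, ...
args = parser.parse_args(["--max-duration", "500", "--num-buckets", "50"])

ami = AmiAsrDataModule(args)

# Training cuts and dataloader (AMI + ICSI mixtures prepared by prepare.sh).
train_cuts = ami.train_cuts()
train_dl = ami.train_dataloaders(train_cuts)

# Evaluation cuts are selected per corpus / split / microphone condition.
dev_cuts = ami.ami_cuts(split="dev", type="sdm")
dev_dl = ami.test_dataloaders(dev_cuts)
```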
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutMix, - DynamicBucketingSampler, - K2SurtDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AmiAsrDataModule: - """ - DataModule for k2 SURT experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--max-duration-valid", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--max-cuts", - type=int, - default=100, - help="Maximum number of cuts in a single batch. You can " - "reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - sources: bool = False, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SurtDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - return_sources=sources, - strict=False, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - quadratic_duration=30.0, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - - logging.info("About to create dev dataset") - validate = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - return_sources=False, - strict=False, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration_valid, - quadratic_duration=30.0, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - return_sources=False, - strict=False, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration_valid, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - return test_dl - - @lru_cache() - def aimix_train_cuts( - self, - rvb_affix: str = "clean", - sources: bool = True, - ) -> CutSet: - logging.info("About to get train cuts") - source_affix = "_sources" if sources else "" - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz" - ) - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) - return cs - - @lru_cache() - def train_cuts( - self, - ) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz" - ) - - @lru_cache() - def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet: - logging.info(f"About to get AMI {split} {type} cuts") - return load_manifest_lazy( - self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz" - ) - - @lru_cache() - def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet: - logging.info(f"About to get ICSI {split} {type} cuts") - return load_manifest_lazy( - self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz" - ) diff --git a/egs/ami/SURT/dprnn_zipformer/beam_search.py b/egs/ami/SURT/dprnn_zipformer/beam_search.py deleted file mode 120000 index 581b29833..000000000 --- a/egs/ami/SURT/dprnn_zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/beam_search.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/decode.py b/egs/ami/SURT/dprnn_zipformer/decode.py deleted file mode 100755 index d1a1eddc9..000000000 --- a/egs/ami/SURT/dprnn_zipformer/decode.py +++ /dev/null @@ -1,622 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./dprnn_zipformer/decode.py \ - --epoch 20 \ - --avg 1 \ - --use-averaged-model false \ - --exp-dir ./dprnn_zipformer/exp_adapt \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./dprnn_zipformer/decode.py \ - --epoch 20 \ - --avg 1 \ - --use-averaged-model false \ - --exp-dir ./dprnn_zipformer/exp_adapt \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./dprnn_zipformer/decode.py \ - --epoch 20 \ - --avg 1 \ - --use-averaged-model false \ - --exp-dir ./dprnn_zipformer/exp_adapt \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AmiAsrDataModule -from beam_search import ( - beam_search, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.utils import EPSILON -from train import add_model_arguments, get_params, get_surt_model - -from icefall import LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_surt_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="dprnn_zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - feature_lens = batch["input_lens"].to(device) - - # Apply the mask encoder - B, T, F = feature.shape - processed = model.mask_encoder(feature) # B,T,F*num_channels - masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) - x_masked = [feature * m for m in masks] - - # Recognition - # Stack the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) - encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) - - if model.joint_encoder_layer is not None: - encoder_out = model.joint_encoder_layer(encoder_out) - - def _group_channels(hyps: List[str]) -> List[List[str]]: - """ - Currently we have a batch of size M*B, where M is the number of - channels and B is the batch size. We need to group the hypotheses - into B groups, each of which contains M hypotheses. 
- - Example: - hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] - _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] - """ - assert len(hyps) == B * params.num_channels - out_hyps = [] - for i in range(B): - out_hyps.append(hyps[i::B]) - return out_hyps - - hyps = [] - if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp)) - - if params.decoding_method == "greedy_search": - return {"greedy_search": _group_channels(hyps)} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: _group_channels(hyps)} - else: - return {f"beam_size_{params.beam_size}": _group_channels(hyps)} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - cut_ids = [cut.id for cut in batch["cuts"]] - cuts_batch = batch["cuts"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - for cut_id, hyp_words in zip(cut_ids, hyps): - # Reference is a list of supervision texts sorted by start time. 
- ref_words = [ - s.text.strip() - for s in sorted( - cuts_batch[cut_id].supervisions, key=lambda s: s.start - ) - ] - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(cut_ids) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_surt_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - num_channels=params.num_channels, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LmScorer.add_arguments(parser) - AmiAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - ), f"Decoding method {params.decoding_method} is not supported." 
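To recap the SURT trick used in `decode_one_batch` above: the two branch outputs are run through a single encoder pass by stacking the masked feature copies along the batch axis, and the resulting hypotheses are regrouped per utterance with `hyps[i::B]`. A small shape-only sketch of that round trip, with illustrative sizes (the real `num_channels` is 2 in this recipe):

```python
import torch

B, T, F = 3, 100, 80   # batch, frames, feature dim (illustrative)
num_channels = 2       # SURT branches; params.num_channels in the recipe

feature = torch.randn(B, T, F)
masks = [torch.rand(B, T, F) for _ in range(num_channels)]

# Stack the masked copies along the batch axis: (num_channels * B, T, F).
h = torch.cat([feature * m for m in masks], dim=0)
assert h.shape == (num_channels * B, T, F)

# Decoding yields num_channels * B hypotheses in channel-major order,
# so utterance i owns hyps[i], hyps[i + B], ... which is exactly hyps[i::B].
hyps = [f"c{c}_utt{i}" for c in range(num_channels) for i in range(B)]
grouped = [hyps[i::B] for i in range(B)]
print(grouped)
# [['c0_utt0', 'c1_utt0'], ['c0_utt1', 'c1_utt1'], ['c0_utt2', 'c1_utt2']]
```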
- params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - 
model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - ami = AmiAsrDataModule(args) - - # NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding. - # However, 99.9% of the segments are shorter than 120s, so this should not - # substantially affect the results. In future, we will implement an overlapped - # inference method to avoid OOM errors. - - test_sets = {} - for split in ["dev", "test"]: - for type in ["ihm-mix", "sdm", "mdm8-bf"]: - test_sets[f"ami-{split}_{type}"] = ( - ami.ami_cuts(split=split, type=type) - .trim_to_supervision_groups(max_pause=0.0) - .filter(lambda c: 0.1 < c.duration < 120.0) - .to_eager() - ) - - for split in ["dev", "test"]: - for type in ["ihm-mix", "sdm"]: - test_sets[f"icsi-{split}_{type}"] = ( - ami.icsi_cuts(split=split, type=type) - .trim_to_supervision_groups(max_pause=0.0) - .filter(lambda c: 0.1 < c.duration < 120.0) - .to_eager() - ) - - for test_set, test_cuts in test_sets.items(): - test_dl = ami.test_dataloaders(test_cuts) - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/ami/SURT/dprnn_zipformer/decoder.py b/egs/ami/SURT/dprnn_zipformer/decoder.py deleted file mode 120000 index c34865c25..000000000 --- a/egs/ami/SURT/dprnn_zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/decoder.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/dprnn.py b/egs/ami/SURT/dprnn_zipformer/dprnn.py deleted file mode 120000 index 8918beb32..000000000 --- a/egs/ami/SURT/dprnn_zipformer/dprnn.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/dprnn.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/encoder_interface.py b/egs/ami/SURT/dprnn_zipformer/encoder_interface.py deleted file mode 120000 index 0ba945d0f..000000000 --- a/egs/ami/SURT/dprnn_zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/export.py b/egs/ami/SURT/dprnn_zipformer/export.py deleted file mode 120000 index 3deae4471..000000000 --- a/egs/ami/SURT/dprnn_zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/export.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/joiner.py b/egs/ami/SURT/dprnn_zipformer/joiner.py deleted file mode 120000 index 79fbe8769..000000000 --- a/egs/ami/SURT/dprnn_zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/joiner.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/model.py b/egs/ami/SURT/dprnn_zipformer/model.py deleted file mode 120000 index ae8c65c99..000000000 --- a/egs/ami/SURT/dprnn_zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/model.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/optim.py b/egs/ami/SURT/dprnn_zipformer/optim.py deleted file mode 120000 index 366d0f7a2..000000000 --- a/egs/ami/SURT/dprnn_zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/optim.py \ No newline at end of file diff 
--git a/egs/ami/SURT/dprnn_zipformer/scaling.py b/egs/ami/SURT/dprnn_zipformer/scaling.py deleted file mode 120000 index f11d49d77..000000000 --- a/egs/ami/SURT/dprnn_zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/scaling.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/scaling_converter.py b/egs/ami/SURT/dprnn_zipformer/scaling_converter.py deleted file mode 120000 index 1533cbe0e..000000000 --- a/egs/ami/SURT/dprnn_zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/test_model.py b/egs/ami/SURT/dprnn_zipformer/test_model.py deleted file mode 120000 index 1259849e0..000000000 --- a/egs/ami/SURT/dprnn_zipformer/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py deleted file mode 100755 index adc6a8495..000000000 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ /dev/null @@ -1,1419 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
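To make the "(excluded)" in the log messages of `decode.py` above concrete: with `--use-averaged-model True`, `--epoch E` and `--avg N`, the script averages over the epoch range `(E - N, E]`, loading only the stored averaged models at `epoch-(E-N).pt` and `epoch-E.pt`. A tiny bookkeeping sketch under illustrative arguments (the actual averaging is done by `icefall.checkpoint.average_checkpoints_with_averaged_model`):

```python
# Illustrative bookkeeping only, mirroring the epoch-range branch in decode.py above.
epoch, avg = 20, 5            # e.g. --epoch 20 --avg 5
start = epoch - avg           # 15; must be >= 1, as decode.py asserts
assert start >= 1, start

filename_start = f"dprnn_zipformer/exp/epoch-{start}.pt"   # excluded endpoint
filename_end = f"dprnn_zipformer/exp/epoch-{epoch}.pt"     # included endpoint
print(f"Averaging over epochs ({start}, {epoch}]: {filename_start} -> {filename_end}")
```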
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -cd egs/ami/SURT/ -./prepare.sh - -./dprnn_zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir dprnn_zipformer/exp \ - --max-duration 650 -""" - -import argparse -import copy -import logging -import warnings -from itertools import chain -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AmiAsrDataModule -from decoder import Decoder -from dprnn import DPRNN -from einops.layers.torch import Rearrange -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import LOG_EPSILON, fix_random_seed -from model import SURT -from optim import Eden, ScaledAdam -from scaling import ScaledLinear, ScaledLSTM -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-mask-encoder-layers", - type=int, - default=4, - help="Number of layers in the DPRNN based mask encoder.", - ) - - parser.add_argument( - "--mask-encoder-dim", - type=int, - default=256, - help="Hidden dimension of the LSTM blocks in DPRNN.", - ) - - parser.add_argument( - "--mask-encoder-segment-size", - type=int, - default=32, - help="Segment size of the SegLSTM in DPRNN. 
Ideally, this should be equal to the " - "decode-chunk-length of the zipformer encoder.", - ) - - parser.add_argument( - "--chunk-width-randomization", - type=bool, - default=False, - help="Whether to randomize the chunk width in DPRNN.", - ) - - # Zipformer config is based on: - # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,2,2,2", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="768,768,768,768,768", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="256,256,256,256,256", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="192,192,192,192,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--use-joint-encoder-layer", - type=str, - default="lstm", - choices=["linear", "lstm", "none"], - help="Whether to use a joint layer to combine all branches.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-init-ckpt", - type=str, - default=None, - help="""The model checkpoint to initialize the model (either full or part). - If not specified, the model is randomly initialized. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.004, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--heat-loss-scale", - type=float, - default=0.2, - help="Scale for HEAT loss on separated sources.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=1, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for SURT - "num_channels": 2, - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed - # parameters for Noam - "model_warm_step": 5000, # arg given to model, not for lrate - # parameters for ctc loss - "beam_size": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def get_mask_encoder_model(params: AttributeDict) -> nn.Module: - mask_encoder = DPRNN( - feature_dim=params.feature_dim, - input_size=params.mask_encoder_dim, - hidden_size=params.mask_encoder_dim, - output_size=params.feature_dim * params.num_channels, - segment_size=params.mask_encoder_segment_size, - num_blocks=params.num_mask_encoder_layers, - chunk_width_randomization=params.chunk_width_randomization, - ) - return mask_encoder - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: - class TakeFirst(nn.Module): - def forward(self, x): - return x[0] - - if params.use_joint_encoder_layer == "linear": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - nn.Linear( - params.num_channels * encoder_dim, params.num_channels * encoder_dim - ), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "lstm": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - ScaledLSTM( - input_size=params.num_channels * encoder_dim, - hidden_size=params.num_channels * encoder_dim, - num_layers=1, - bias=True, - batch_first=True, - dropout=0.0, - bidirectional=False, - ), - TakeFirst(), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "none": - joint_layer = None - else: - raise ValueError( - f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" - ) - return joint_layer - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - 
encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_surt_model( - params: AttributeDict, -) -> nn.Module: - mask_encoder = get_mask_encoder_model(params) - encoder = get_encoder_model(params) - joint_layer = get_joint_encoder_layer(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = SURT( - mask_encoder=mask_encoder, - encoder=encoder, - joint_encoder_layer=joint_layer, - decoder=decoder, - joiner=joiner, - num_channels=params.num_channels, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
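-      scheduler:
-        The learning rate scheduler used in the training.
-      rank:
-        The rank of the node in DDP training. If no DDP is used, it should be
-        set to 0; only the node with rank 0 writes checkpoints to disk.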
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor: - """ - Compute HEAT loss for separated sources using the output of mask encoder. - Args: - x_masked: - The output of mask encoder. It is a tensor of shape (B, T, C). - batch: - A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()` - for the content in it. - num_channels: - The number of output branches in the SURT model. - """ - B, T, D = x_masked[0].shape - device = x_masked[0].device - - # Create training targets for each channel. - targets = [] - for i in range(num_channels): - target = torch.ones_like(x_masked[i]) * LOG_EPSILON - targets.append(target) - - source_feats = batch["source_feats"] - source_boundaries = batch["source_boundaries"] - input_lens = batch["input_lens"].to(device) - # Assign sources to channels based on the HEAT criteria - for b in range(B): - cut_source_feats = source_feats[b] - cut_source_boundaries = source_boundaries[b] - last_seg_end = [0 for _ in range(num_channels)] - for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries): - assigned = False - end = min(end, T) - source_feat = source_feat[: end - start, :] - for i in range(num_channels): - if start >= last_seg_end[i]: - targets[i][b, start:end, :] += source_feat.to(device) - last_seg_end[i] = max(end, last_seg_end[i]) - assigned = True - break - if not assigned: - min_end_channel = last_seg_end.index(min(last_seg_end)) - targets[min_end_channel][b, start:end, :] += source_feat.to(device) - last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel]) - - # Get padding mask based on input lengths - pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1) - pad_mask = pad_mask.unsqueeze(-1) - - # Compute masked loss for each channel - losses = torch.zeros((num_channels, B, T, D), device=device) - for i in range(num_channels): - loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none") - # Apply padding mask to loss - loss.masked_fill_(pad_mask, 0) - losses[i] = loss - - # loss: C x B x T x D. pad_mask: B x T x 1 - # We want to compute loss for each item in the batch. Each item has loss given - # by the sum over C, and average over T and D. For T, we need to use the padding. - loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device) - return loss - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"].to(device) - feature_lens = batch["input_lens"].to(device) - - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - text = [val for tup in zip(*batch["text"]) for val in tup] - assert len(text) == len(feature) * params.num_channels - - # Convert all channel texts to token IDs and create a ragged tensor. - y = sp.encode(text, out_type=int) - y = k2.RaggedTensor(y).to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.model_warm_step - - with torch.set_grad_enabled(is_training): - (simple_loss, pruned_loss, ctc_loss, x_masked) = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - reduction="none", - subsampling_factor=params.subsampling_factor, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - ctc_loss_is_finite = torch.isfinite(ctc_loss) - - # Compute HEAT loss - if is_training and params.heat_loss_scale > 0.0: - heat_loss = compute_heat_loss( - x_masked, batch, num_channels=params.num_channels - ) - else: - heat_loss = torch.tensor(0.0, device=device) - - heat_loss_is_finite = torch.isfinite(heat_loss) - is_finite = ( - simple_loss_is_finite - & pruned_loss_is_finite - & ctc_loss_is_finite - & heat_loss_is_finite - ) - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_losses: {simple_loss}\n" - f"pruned_losses: {pruned_loss}\n" - f"ctc_losses: {ctc_loss}\n" - f"heat_losses: {heat_loss}\n" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - ctc_loss = ctc_loss[ctc_loss_is_finite] - heat_loss = heat_loss[heat_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if ( - torch.all(~simple_loss_is_finite) - or torch.all(~pruned_loss_is_finite) - or torch.all(~ctc_loss_is_finite) - or torch.all(~heat_loss_is_finite) - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss_sum = simple_loss.sum() - pruned_loss_sum = pruned_loss.sum() - ctc_loss_sum = ctc_loss.sum() - heat_loss_sum = heat_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. 
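-        # A concrete walk-through with the defaults in this file
-        # (simple-loss-scale s = 0.5, model_warm_step = 5000):
-        #   batch_idx_train = 0      -> simple_loss_scale = 1.00, pruned_loss_scale = 0.10
-        #   batch_idx_train = 2500   -> simple_loss_scale = 0.75, pruned_loss_scale = 0.55
-        #   batch_idx_train >= 5000  -> simple_loss_scale = 0.50, pruned_loss_scale = 1.00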
- simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = ( - simple_loss_scale * simple_loss_sum - + pruned_loss_scale * pruned_loss_sum - + params.ctc_loss_scale * ctc_loss_sum - + params.heat_loss_scale * heat_loss_sum - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss_sum.detach().cpu().item() - info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() - if params.ctc_loss_scale > 0.0: - info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() - if params.heat_loss_scale > 0.0: - info["heat_loss"] = heat_loss_sum.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. 
- tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - torch.cuda.empty_cache() - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = batch["inputs"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
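-            # Concretely: every 100 batches the current scale is read; if it is
-            # below 1.0 (or below 8.0, on every 400th batch) it is doubled via
-            # scaler.update(). A warning is logged if it drops below 0.01, and
-            # training is aborted below 1e-5, which usually means the losses
-            # keep producing inf/nan gradients.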
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - - if checkpoints is None and params.model_init_ckpt is not None: - logging.info( - f"Initializing model with checkpoint from {params.model_init_ckpt}" - ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) - model.load_state_dict(init_ckpt["model"], strict=False) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - ami = AmiAsrDataModule(args) - - train_cuts = ami.aimix_train_cuts(rvb_affix="comb", sources=True) - dev_cuts = ami.ami_cuts(split="dev", type="ihm-mix") - dev_cuts = dev_cuts.trim_to_supervision_groups(max_pause=0.0).filter( - lambda c: 0.2 <= c.duration <= 60.0 - ) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = ami.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - sources=True, - ) - valid_dl = ami.valid_dataloaders(dev_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - 
fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = [sp.encode(text_ch) for text_ch in batch["text"]] - num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] - logging.info(f"num tokens: {num_tokens}") - - -def main(): - parser = get_parser() - AmiAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - -if __name__ == "__main__": - main() diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py deleted file mode 100755 index ac5b0dadc..000000000 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ /dev/null @@ -1,1410 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -# ./dprnn_zipformer/train.py should be run before this script. 
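-# This script adapts the model trained by ./dprnn_zipformer/train.py: it is
-# initialized from the checkpoint passed via --model-init-ckpt and fine-tuned
-# with a much smaller base learning rate (default 0.0001).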
- -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./dprnn_zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir dprnn_zipformer/exp_adapt \ - --model-init-ckpt dprnn_zipformer/exp/epoch-30.pt \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from itertools import chain -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AmiAsrDataModule -from decoder import Decoder -from dprnn import DPRNN -from einops.layers.torch import Rearrange -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import LOG_EPSILON, fix_random_seed -from model import SURT -from optim import Eden, ScaledAdam -from scaling import ScaledLinear, ScaledLSTM -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-mask-encoder-layers", - type=int, - default=4, - help="Number of layers in the DPRNN based mask encoder.", - ) - - parser.add_argument( - "--mask-encoder-dim", - type=int, - default=256, - help="Hidden dimension of the LSTM blocks in DPRNN.", - ) - - parser.add_argument( - "--mask-encoder-segment-size", - type=int, - default=32, - help="Segment size of the SegLSTM in DPRNN. 
Ideally, this should be equal to the " - "decode-chunk-length of the zipformer encoder.", - ) - - parser.add_argument( - "--chunk-width-randomization", - type=bool, - default=False, - help="Whether to randomize the chunk width in DPRNN.", - ) - - # Zipformer config is based on: - # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,2,2,2", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="768,768,768,768,768", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="256,256,256,256,256", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="192,192,192,192,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--use-joint-encoder-layer", - type=str, - default="linear", - choices=["linear", "lstm", "none"], - help="Whether to use a joint layer to combine all branches.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-init-ckpt", - type=str, - default=None, - help="""The model checkpoint to initialize the model (either full or part). - If not specified, the model is randomly initialized. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.0001, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=2, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=1, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for SURT - "num_channels": 2, - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed - # parameters for Noam - "model_warm_step": 5000, # arg given to model, not for lrate - # parameters for ctc loss - "beam_size": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def get_mask_encoder_model(params: AttributeDict) -> nn.Module: - mask_encoder = DPRNN( - feature_dim=params.feature_dim, - input_size=params.mask_encoder_dim, - hidden_size=params.mask_encoder_dim, - output_size=params.feature_dim * params.num_channels, - segment_size=params.mask_encoder_segment_size, - num_blocks=params.num_mask_encoder_layers, - chunk_width_randomization=params.chunk_width_randomization, - ) - return mask_encoder - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: - class TakeFirst(nn.Module): - def forward(self, x): - return x[0] - - if params.use_joint_encoder_layer == "linear": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - nn.Linear( - params.num_channels * encoder_dim, params.num_channels * encoder_dim - ), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "lstm": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - ScaledLSTM( - input_size=params.num_channels * encoder_dim, - hidden_size=params.num_channels * encoder_dim, - num_layers=1, - bias=True, - batch_first=True, - dropout=0.0, - bidirectional=False, - ), - TakeFirst(), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "none": - joint_layer = None - else: - raise ValueError( - f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" - ) - return joint_layer - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - 
decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_surt_model( - params: AttributeDict, -) -> nn.Module: - mask_encoder = get_mask_encoder_model(params) - encoder = get_encoder_model(params) - joint_layer = get_joint_encoder_layer(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = SURT( - mask_encoder=mask_encoder, - encoder=encoder, - joint_encoder_layer=joint_layer, - decoder=decoder, - joiner=joiner, - num_channels=params.num_channels, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor: - """ - Compute HEAT loss for separated sources using the output of mask encoder. - Args: - x_masked: - The output of mask encoder. It is a tensor of shape (B, T, C). - batch: - A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()` - for the content in it. - num_channels: - The number of output branches in the SURT model. - """ - B, T, D = x_masked[0].shape - device = x_masked[0].device - - # Create training targets for each channel. - targets = [] - for i in range(num_channels): - target = torch.ones_like(x_masked[i]) * LOG_EPSILON - targets.append(target) - - source_feats = batch["source_feats"] - source_boundaries = batch["source_boundaries"] - input_lens = batch["input_lens"].to(device) - # Assign sources to channels based on the HEAT criteria - for b in range(B): - cut_source_feats = source_feats[b] - cut_source_boundaries = source_boundaries[b] - last_seg_end = [0 for _ in range(num_channels)] - for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries): - assigned = False - for i in range(num_channels): - if start >= last_seg_end[i]: - targets[i][b, start:end, :] += source_feat.to(device) - last_seg_end[i] = max(end, last_seg_end[i]) - assigned = True - break - if not assigned: - min_end_channel = last_seg_end.index(min(last_seg_end)) - targets[min_end_channel][b, start:end, :] += source_feat - last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel]) - - # Get padding mask based on input lengths - pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1) - pad_mask = pad_mask.unsqueeze(-1) - - # Compute masked loss for each channel - losses = torch.zeros((num_channels, B, T, D), device=device) - for i in range(num_channels): - loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none") - # Apply padding mask to loss - loss.masked_fill_(pad_mask, 0) - losses[i] = loss - - # loss: C x B x T x D. pad_mask: B x T x 1 - # We want to compute loss for each item in the batch. Each item has loss given - # by the sum over C, and average over T and D. For T, we need to use the padding. - loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device) - return loss - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"].to(device) - feature_lens = batch["input_lens"].to(device) - - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - text = [val for tup in zip(*batch["text"]) for val in tup] - assert len(text) == len(feature) * params.num_channels - - # Convert all channel texts to token IDs and create a ragged tensor. - y = sp.encode(text, out_type=int) - y = k2.RaggedTensor(y).to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.model_warm_step - - with torch.set_grad_enabled(is_training): - (simple_loss, pruned_loss, ctc_loss, x_masked) = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - reduction="none", - subsampling_factor=params.subsampling_factor, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - ctc_loss_is_finite = torch.isfinite(ctc_loss) - - # Compute HEAT loss - if is_training and params.heat_loss_scale > 0.0: - heat_loss = compute_heat_loss( - x_masked, batch, num_channels=params.num_channels - ) - else: - heat_loss = torch.tensor(0.0, device=device) - - heat_loss_is_finite = torch.isfinite(heat_loss) - is_finite = ( - simple_loss_is_finite - & pruned_loss_is_finite - & ctc_loss_is_finite - & heat_loss_is_finite - ) - if not torch.all(is_finite): - # logging.info( - # "Not all losses are finite!\n" - # f"simple_losses: {simple_loss}\n" - # f"pruned_losses: {pruned_loss}\n" - # f"ctc_losses: {ctc_loss}\n" - # f"heat_losses: {heat_loss}\n" - # ) - # display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - ctc_loss = ctc_loss[ctc_loss_is_finite] - heat_loss = heat_loss[heat_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if ( - torch.all(~simple_loss_is_finite) - or torch.all(~pruned_loss_is_finite) - or torch.all(~ctc_loss_is_finite) - or torch.all(~heat_loss_is_finite) - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss_sum = simple_loss.sum() - pruned_loss_sum = pruned_loss.sum() - ctc_loss_sum = ctc_loss.sum() - heat_loss_sum = heat_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. 
- simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = ( - simple_loss_scale * simple_loss_sum - + pruned_loss_scale * pruned_loss_sum - + params.ctc_loss_scale * ctc_loss_sum - + params.heat_loss_scale * heat_loss_sum - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss_sum.detach().cpu().item() - info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() - if params.ctc_loss_scale > 0.0: - info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() - if params.heat_loss_scale > 0.0: - info["heat_loss"] = heat_loss_sum.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. 
- tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - torch.cuda.empty_cache() - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = batch["inputs"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
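To make the comment above concrete: after a streak of overflows the AMP scaler can end up with a very small loss scale, and GradScaler's built-in growth interval cannot be made to depend on the current scale, so the loop below bumps it manually. A toy trace of that rule (thresholds copied from the code; the starting scale and batch indices are made up):

```python
# Toy trace of the manual grad-scale growth rule: checked every 100 batches,
# doubled whenever the scale is below 1.0, and additionally every 400 batches
# while it is still below 8.0.
scale = 0.125  # pretend the scaler collapsed after repeated overflows
for batch_idx in range(0, 1300, 100):
    if scale < 1.0 or (scale < 8.0 and batch_idx % 400 == 0):
        scale *= 2.0
# The scale recovers quickly (0.25, 0.5, 1.0) while it is below 1.0,
# then doubles only at batches 400, 800 and 1200 until it reaches 8.0.
print(scale)  # 8.0
```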
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - - if checkpoints is None and params.model_init_ckpt is not None: - logging.info( - f"Initializing model with checkpoint from {params.model_init_ckpt}" - ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) - model.load_state_dict(init_ckpt["model"], strict=False) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - ami = AmiAsrDataModule(args) - - train_cuts = ami.train_cuts() - train_cuts = train_cuts.filter(lambda c: 0.5 <= c.duration <= 35.0) - dev_cuts = ami.ami_cuts(split="dev", type="ihm-mix") - dev_cuts = dev_cuts.trim_to_supervision_groups(max_pause=0.0).filter( - lambda c: 0.2 <= c.duration <= 60.0 - ) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = ami.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - ) - valid_dl = ami.valid_dataloaders(dev_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch 
- 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = [sp.encode(text_ch) for text_ch in batch["text"]] - num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] - logging.info(f"num tokens: {num_tokens}") - - -def main(): - parser = get_parser() - AmiAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - -if __name__ == "__main__": - main() diff --git a/egs/ami/SURT/dprnn_zipformer/zipformer.py b/egs/ami/SURT/dprnn_zipformer/zipformer.py deleted file mode 120000 index 59b772024..000000000 --- a/egs/ami/SURT/dprnn_zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../libricss/SURT/dprnn_zipformer/zipformer.py \ No newline at end of file diff --git a/egs/ami/SURT/local/add_source_feats.py b/egs/ami/SURT/local/add_source_feats.py deleted file mode 100755 index 0917b88a6..000000000 --- a/egs/ami/SURT/local/add_source_feats.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file adds source features as temporal arrays to the mixture manifests. -It looks for manifests in the directory data/manifests. 
-""" -import logging -from pathlib import Path - -import numpy as np -from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy -from tqdm import tqdm - - -def add_source_feats(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - logging.info("Reading mixed cuts") - mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz") - mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz") - - logging.info("Reading source cuts") - source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz") - - logging.info("Adding source features to the mixed cuts") - pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features") - with CutSet.open_writer( - src_dir / "cuts_train_clean_sources.jsonl.gz" - ) as cut_writer_clean, CutSet.open_writer( - src_dir / "cuts_train_reverb_sources.jsonl.gz" - ) as cut_writer_reverb, LilcomChunkyWriter( - output_dir / "feats_train_clean_sources" - ) as source_feat_writer: - for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb): - assert cut_reverb.id == cut_clean.id + "_rvb" - source_feats = [] - source_feat_offsets = [] - cur_offset = 0 - for sup in sorted( - cut_clean.supervisions, key=lambda s: (s.start, s.speaker) - ): - source_cut = source_cuts[sup.id] - source_feats.append(source_cut.load_features()) - source_feat_offsets.append(cur_offset) - cur_offset += source_cut.num_frames - cut_clean.source_feats = source_feat_writer.store_array( - cut_clean.id, np.concatenate(source_feats, axis=0) - ) - cut_clean.source_feat_offsets = source_feat_offsets - cut_writer_clean.write(cut_clean) - # Also write the reverb cut - cut_reverb.source_feats = cut_clean.source_feats - cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets - cut_writer_reverb.write(cut_reverb) - pbar.update(1) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - add_source_feats() diff --git a/egs/ami/SURT/local/compute_fbank_aimix.py b/egs/ami/SURT/local/compute_fbank_aimix.py deleted file mode 100755 index 91b3a060b..000000000 --- a/egs/ami/SURT/local/compute_fbank_aimix.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the synthetically mixed AMI and ICSI -train set. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. 
-""" -import logging -import random -import warnings -from pathlib import Path - -import torch -import torch.multiprocessing -import torchaudio -from lhotse import ( - AudioSource, - LilcomChunkyWriter, - Recording, - load_manifest, - load_manifest_lazy, -) -from lhotse.audio import set_ffmpeg_torchaudio_info_enabled -from lhotse.cut import MixedCut, MixTrack, MultiCut -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.utils import fix_random_seed, uuid4 -from tqdm import tqdm - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") -torchaudio.set_audio_backend("soundfile") -set_ffmpeg_torchaudio_info_enabled(False) - - -def compute_fbank_aimix(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz") - - # only uses RIRs and noises from REVERB challenge - real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter( - lambda r: "RVB2014" in r.id - ) - noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter( - lambda r: "RVB2014" in r.id - ) - - # Apply perturbation to the training cuts - logging.info("Applying perturbation to the training cuts") - train_cuts_rvb = train_cuts.map( - lambda c: augment( - c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True - ) - ) - - logging.info("Extracting fbank features for training cuts") - _ = train_cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "ai-mix_feats_clean", - manifest_path=src_dir / "cuts_train_clean.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - _ = train_cuts_rvb.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "ai-mix_feats_reverb", - manifest_path=src_dir / "cuts_train_reverb.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False): - """ - Given a mixed cut, this function optionally applies the following augmentations: - - Perturbing the SNRs of the tracks (in range [-5, 5] dB) - - Reverberation using a randomly selected RIR - - Adding noise - - Perturbing the loudness (in range [-20, -25] dB) - """ - out_cut = cut.drop_features() - - # Perturb the SNRs (optional) - if perturb_snr: - snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))] - for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)): - if i == 0: - # Skip the first track since it is the reference - continue - track.snr = snr - - # Reverberate the cut (optional) - if rirs is not None: - # Select an RIR at random - rir = random.choice(rirs) - # Select a channel at random - rir_channel = 
random.choice(list(range(rir.num_channels))) - # Reverberate the cut - out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel]) - - # Add noise (optional) - if noises is not None: - # Select a noise recording at random - noise = random.choice(noises).to_cut() - if isinstance(noise, MultiCut): - noise = noise.to_mono()[0] - # Select an SNR at random - snr = random.uniform(10, 30) - # Repeat the noise to match the duration of the cut - noise = repeat_cut(noise, out_cut.duration) - out_cut = MixedCut( - id=out_cut.id, - tracks=[ - MixTrack(cut=out_cut, type="MixedCut"), - MixTrack(cut=noise, type="DataCut", snr=snr), - ], - ) - - # Perturb the loudness (optional) - if perturb_loudness: - target_loudness = random.uniform(-20, -25) - out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True) - return out_cut - - -def repeat_cut(cut, duration): - while cut.duration < duration: - cut = cut.mix(cut, offset_other_by=cut.duration) - return cut.truncate(duration=duration) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - fix_random_seed(42) - compute_fbank_aimix() diff --git a/egs/ami/SURT/local/compute_fbank_ami.py b/egs/ami/SURT/local/compute_fbank_ami.py deleted file mode 100755 index 351b41765..000000000 --- a/egs/ami/SURT/local/compute_fbank_ami.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the AMI dataset. -We compute features for full recordings (i.e., without trimming to supervisions). -This way we can create arbitrary segmentations later. - -The generated fbank features are saved in data/fbank. -""" -import logging -import math -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
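Looking back at `repeat_cut` in compute_fbank_aimix.py above: each pass mixes the noise cut with a copy of itself offset by its own duration, so the length doubles per iteration and only a logarithmic number of steps is needed before truncation. A duration-only sketch in plain Python (no lhotse objects):

```python
# Duration-only sketch of repeat_cut: double until long enough, then truncate.
def repeated_duration(noise_dur: float, target: float) -> float:
    while noise_dur < target:
        noise_dur += noise_dur        # cut.mix(cut, offset_other_by=cut.duration)
    return min(noise_dur, target)     # cut.truncate(duration=target)

print(repeated_duration(3.7, 30.0))   # 3.7 -> 7.4 -> 14.8 -> 29.6 -> 59.2, truncated to 30.0
```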
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_ami(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests = {} - for part in ["ihm-mix", "sdm", "mdm8-bf"]: - manifests[part] = read_manifests_if_cached( - dataset_parts=["train", "dev", "test"], - output_dir=src_dir, - prefix=f"ami-{part}", - suffix="jsonl.gz", - ) - - for part in ["ihm-mix", "sdm", "mdm8-bf"]: - for split in ["train", "dev", "test"]: - logging.info(f"Processing {part} {split}") - cuts = CutSet.from_manifests( - **manifests[part][split] - ).compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"ami-{part}_{split}_feats", - manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_ami() diff --git a/egs/ami/SURT/local/compute_fbank_icsi.py b/egs/ami/SURT/local/compute_fbank_icsi.py deleted file mode 100755 index 4e2ff3f3b..000000000 --- a/egs/ami/SURT/local/compute_fbank_icsi.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the ICSI dataset. -We compute features for full recordings (i.e., without trimming to supervisions). -This way we can create arbitrary segmentations later. - -The generated fbank features are saved in data/fbank. -""" -import logging -import math -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_icsi(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests = {} - for part in ["ihm-mix", "sdm"]: - manifests[part] = read_manifests_if_cached( - dataset_parts=["train"], - output_dir=src_dir, - prefix=f"icsi-{part}", - suffix="jsonl.gz", - ) - - for part in ["ihm-mix", "sdm"]: - for split in ["train"]: - logging.info(f"Processing {part} {split}") - cuts = CutSet.from_manifests( - **manifests[part][split] - ).compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"icsi-{part}_{split}_feats", - manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_icsi() diff --git a/egs/ami/SURT/local/compute_fbank_ihm.py b/egs/ami/SURT/local/compute_fbank_ihm.py deleted file mode 100755 index 56f54aa21..000000000 --- a/egs/ami/SURT/local/compute_fbank_ihm.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the trimmed sub-segments which will be -used for simulating the training mixtures. - -The generated fbank features are saved in data/fbank. -""" -import logging -import math -from pathlib import Path - -import torch -import torch.multiprocessing -import torchaudio -from lhotse import CutSet, LilcomChunkyWriter, load_manifest -from lhotse.audio import set_ffmpeg_torchaudio_info_enabled -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached -from tqdm import tqdm - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
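The script that begins here prepares the trimmed IHM source segments for mixture simulation; further below it keeps the original cuts and adds 0.9x and 1.1x speed-perturbed copies before extracting features. A rough back-of-the-envelope sketch of what that does to the amount of source audio (80 hours is a made-up figure, not the actual AMI+ICSI total):

```python
# Speed perturbation keeps the original audio and adds two resampled copies:
# the 0.9x copy is ~11% longer, the 1.1x copy ~9% shorter, so the total
# duration roughly triples.
orig_hours = 80.0  # illustrative only
total = orig_hours + orig_hours / 0.9 + orig_hours / 1.1
print(f"{total:.1f} hours after perturbation")  # ~241.6 hours
```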
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") -torchaudio.set_audio_backend("soundfile") -set_ffmpeg_torchaudio_info_enabled(False) - - -def compute_fbank_ihm(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests = {} - for data in ["ami", "icsi"]: - manifests[data] = read_manifests_if_cached( - dataset_parts=["train"], - output_dir=src_dir, - types=["recordings", "supervisions"], - prefix=f"{data}-ihm", - suffix="jsonl.gz", - ) - - logging.info("Computing features") - for data in ["ami", "icsi"]: - cs = CutSet.from_manifests(**manifests[data]["train"]) - cs = cs.trim_to_supervisions(keep_overlapping=False) - cs = cs.normalize_loudness(target=-23.0, affix_id=False) - cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1) - _ = cs.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"{data}-ihm_train_feats", - manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_ihm() diff --git a/egs/ami/SURT/local/prepare_ami_train_cuts.py b/egs/ami/SURT/local/prepare_ami_train_cuts.py deleted file mode 100755 index 72fced70d..000000000 --- a/egs/ami/SURT/local/prepare_ami_train_cuts.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file creates AMI train segments. -""" -import logging -import math -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import LilcomChunkyWriter, load_manifest_lazy -from lhotse.cut import Cut, CutSet -from lhotse.utils import EPSILON, add_durations -from tqdm import tqdm - - -def cut_into_windows(cuts: CutSet, duration: float): - """ - This function takes a CutSet and cuts each cut into windows of roughly - `duration` seconds. By roughly, we mean that we try to adjust for the last supervision - that exceeds the duration, or is shorter than the duration. 
- """ - res = [] - with tqdm() as pbar: - for cut in cuts: - pbar.update(1) - sups = cut.index_supervisions()[cut.id] - sr = cut.sampling_rate - start = 0.0 - end = duration - num_tries = 0 - while start < cut.duration and num_tries < 2: - # Find the supervision that are cut by the window endpoint - hitlist = [iv for iv in sups.at(end) if iv.begin < end] - # If there are no supervisions, we are done - if not hitlist: - res.append( - cut.truncate( - offset=start, - duration=add_durations(end, -start, sampling_rate=sr), - keep_excessive_supervisions=False, - ) - ) - # Update the start and end for the next window - start = end - end = add_durations(end, duration, sampling_rate=sr) - else: - # find ratio of durations cut by the window endpoint - ratios = [ - add_durations(end, -iv.end, sampling_rate=sr) / iv.length() - for iv in hitlist - ] - # we retain the supervisions that have >50% of their duration - # in the window, and discard the others - retained = [] - discarded = [] - for iv, ratio in zip(hitlist, ratios): - if ratio > 0.5: - retained.append(iv) - else: - discarded.append(iv) - cur_end = max(iv.end for iv in retained) if retained else end - res.append( - cut.truncate( - offset=start, - duration=add_durations(cur_end, -start, sampling_rate=sr), - keep_excessive_supervisions=False, - ) - ) - # For the next window, we start at the earliest discarded supervision - next_start = min(iv.begin for iv in discarded) if discarded else end - next_end = add_durations(next_start, duration, sampling_rate=sr) - # It may happen that next_start is the same as start, in which case - # we will advance the window anyway - if next_start == start: - logging.warning( - f"Next start is the same as start: {next_start} == {start} for cut {cut.id}" - ) - start = end + EPSILON - end = add_durations(start, duration, sampling_rate=sr) - num_tries += 1 - else: - start = next_start - end = next_end - return CutSet.from_cuts(res) - - -def prepare_train_cuts(): - src_dir = Path("data/manifests") - - logging.info("Loading the manifests") - train_cuts_ihm = load_manifest_lazy( - src_dir / "cuts_ami-ihm-mix_train.jsonl.gz" - ).map(lambda c: c.with_id(f"{c.id}_ihm-mix")) - train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map( - lambda c: c.with_id(f"{c.id}_sdm") - ) - train_cuts_mdm = load_manifest_lazy( - src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz" - ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf")) - - # Combine all cuts into one CutSet - train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm - - train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5) - train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0) - - # Combine the two segmentations - train_all = train_cuts_1 + train_cuts_2 - - # At this point, some of the cuts may be very long. We will cut them into windows of - # roughly 30 seconds. 
- logging.info("Cutting the segments into windows of 30 seconds") - train_all_30 = cut_into_windows(train_all, duration=30.0) - logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}") - - # Show statistics - train_all.describe(full=True) - - # Save the cuts - logging.info("Saving the cuts") - train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_train_cuts() diff --git a/egs/ami/SURT/local/prepare_icsi_train_cuts.py b/egs/ami/SURT/local/prepare_icsi_train_cuts.py deleted file mode 100755 index 818e26bfb..000000000 --- a/egs/ami/SURT/local/prepare_icsi_train_cuts.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file creates ICSI train segments. -""" -import logging -from pathlib import Path - -from lhotse import load_manifest_lazy -from prepare_ami_train_cuts import cut_into_windows - - -def prepare_train_cuts(): - src_dir = Path("data/manifests") - - logging.info("Loading the manifests") - train_cuts_ihm = load_manifest_lazy( - src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz" - ).map(lambda c: c.with_id(f"{c.id}_ihm-mix")) - train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map( - lambda c: c.with_id(f"{c.id}_sdm") - ) - - # Combine all cuts into one CutSet - train_cuts = train_cuts_ihm + train_cuts_sdm - - train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5) - train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0) - - # Combine the two segmentations - train_all = train_cuts_1 + train_cuts_2 - - # At this point, some of the cuts may be very long. We will cut them into windows of - # roughly 30 seconds. 
- logging.info("Cutting the segments into windows of 30 seconds") - train_all_30 = cut_into_windows(train_all, duration=30.0) - logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}") - - # Show statistics - train_all.describe(full=True) - - # Save the cuts - logging.info("Saving the cuts") - train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_train_cuts() diff --git a/egs/ami/SURT/local/prepare_lang_bpe.py b/egs/ami/SURT/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/ami/SURT/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/ami/SURT/local/train_bpe_model.py b/egs/ami/SURT/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/ami/SURT/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ami/SURT/prepare.sh b/egs/ami/SURT/prepare.sh deleted file mode 100755 index ea4e5baf2..000000000 --- a/egs/ami/SURT/prepare.sh +++ /dev/null @@ -1,195 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/ami -# You can find audio and transcripts for AMI in this path. -# -# - $dl_dir/icsi -# You can find audio and transcripts for ICSI in this path. -# -# - $dl_dir/rirs_noises -# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/. -# -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data -vocab_size=500 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/amicorpus, - # you can create a symlink - # - # ln -sfv /path/to/amicorpus $dl_dir/amicorpus - # - if [ ! -d $dl_dir/amicorpus ]; then - for mic in ihm ihm-mix sdm mdm8-bf; do - lhotse download ami --mic $mic $dl_dir/amicorpus - done - fi - - # If you have pre-downloaded it to /path/to/icsi, - # you can create a symlink - # - # ln -sfv /path/to/icsi $dl_dir/icsi - # - if [ ! -d $dl_dir/icsi ]; then - lhotse download icsi $dl_dir/icsi - fi - - # If you have pre-downloaded it to /path/to/rirs_noises, - # you can create a symlink - # - # ln -sfv /path/to/rirs_noises $dl_dir/ - # - if [ ! -d $dl_dir/rirs_noises ]; then - lhotse download rirs_noises $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare AMI manifests" - # We assume that you have downloaded the AMI corpus - # to $dl_dir/amicorpus. We perform text normalization for the transcripts. 
- mkdir -p data/manifests - for mic in ihm ihm-mix sdm mdm8-bf; do - log "Preparing AMI manifest for $mic" - lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/ - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare ICSI manifests" - # We assume that you have downloaded the ICSI corpus - # to $dl_dir/icsi. We perform text normalization for the transcripts. - mkdir -p data/manifests - log "Preparing ICSI manifest" - for mic in ihm ihm-mix sdm; do - lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/ - done -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare RIRs" - # We assume that you have downloaded the RIRS_NOISES corpus - # to $dl_dir/rirs_noises - lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 3: Extract features for AMI and ICSI recordings" - python local/compute_fbank_ami.py - python local/compute_fbank_icsi.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Create sources for simulating mixtures" - # In the following script, we speed-perturb the IHM recordings and extract features. - python local/compute_fbank_ihm.py - lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \ - data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\ - lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\ - lhotse filter 'duration<=12.0' - - |\ - shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Create training mixtures" - lhotse workflows simulate-meetings \ - --method conversational \ - --same-spk-pause 0.5 \ - --diff-spk-pause 0.5 \ - --diff-spk-overlap 1.0 \ - --prob-diff-spk-overlap 0.8 \ - --num-meetings 200000 \ - --num-speakers-per-meeting 2,3 \ - --max-duration-per-speaker 15.0 \ - --max-utterances-per-speaker 3 \ - --seed 1234 \ - --num-jobs 2 \ - data/manifests/ihm_cuts_train_trimmed.jsonl.gz \ - data/manifests/ai-mix_cuts_clean.jsonl.gz - - python local/compute_fbank_aimix.py - - # Add source features to the manifest (will be used for masking loss) - # This may take ~2 hours. - python local/add_source_feats.py - - # Combine clean and reverb - cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \ - <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\ - shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Create training mixtures from real sessions" - python local/prepare_ami_train_cuts.py - python local/prepare_icsi_train_cuts.py - - # Combine AMI and ICSI - cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \ - <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\ - shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)." 
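The stage that begins here dumps the raw supervision texts with gunzip and jq so they can be fed to the BPE trainer. For reference, a rough Python equivalent (paths copied from the stage; this sketch is not part of the recipe):

```python
# Rough Python equivalent of the transcript-dumping stage: read the two
# supervision manifests and write one transcript per line.
import gzip
import json
import os

supervision_files = [
    "data/manifests/ami-sdm_supervisions_train.jsonl.gz",
    "data/manifests/icsi-sdm_supervisions_train.jsonl.gz",
]
os.makedirs("data/lm", exist_ok=True)
with open("data/lm/transcript_words.txt", "w") as out:
    for path in supervision_files:
        with gzip.open(path, "rt") as f:
            for line in f:
                text = json.loads(line).get("text", "")
                if text:
                    out.write(text + "\n")
```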
- mkdir -p data/lm - cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \ - <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \ - > data/lm/transcript_words.txt -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)" - - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - # Add special words to words.txt - echo " 0" > $lang_dir/words.txt - echo "!SIL 1" >> $lang_dir/words.txt - echo " 2" >> $lang_dir/words.txt - - # Add regular words to words.txt - cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt - - # Add remaining special word symbols expected by LM scripts. - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo "#0 ${num_words}" >> $lang_dir/words.txt - - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript data/lm/transcript_words.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi -fi diff --git a/egs/ami/SURT/shared b/egs/ami/SURT/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/ami/SURT/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/audioset/AT/README.md b/egs/audioset/AT/README.md deleted file mode 100644 index 2506d41e5..000000000 --- a/egs/audioset/AT/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# Introduction - -This is an audio tagging recipe for [Audioset](https://research.google.com/audioset/#/). It aims at predicting the sound events of an audio clip. - -[./RESULTS.md](./RESULTS.md) contains the latest results. 
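Since audio tagging is a multi-label task, the model produces one score per AudioSet event class (527 in this recipe), and each class is typically thresholded independently, for example through a sigmoid, rather than by picking a single best class. A minimal sketch of turning clip-level logits into predicted events (the random logits and the 0.5 threshold are illustrative only):

```python
# Minimal sketch of multi-label audio tagging over 527 AudioSet classes.
import torch

num_events = 527
logits = torch.randn(num_events)   # stand-in for the model's clip-level output
probs = torch.sigmoid(logits)      # independent per-class probabilities
predicted = (probs > 0.5).nonzero(as_tuple=True)[0].tolist()
print(f"{len(predicted)} events predicted for this clip")
```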
- - -# Zipformer - -| Encoder | Feature type | -| --------| -------------| -| Zipformer | Frame level fbank| diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md deleted file mode 100644 index 36613db03..000000000 --- a/egs/audioset/AT/RESULTS.md +++ /dev/null @@ -1,119 +0,0 @@ -## Results - -### zipformer -See for more details - -[zipformer](./zipformer) - -#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M - -You can find a pretrained model, training logs, decoding logs, and decoding results at: - - -The model achieves the following mean averaged precision on AudioSet: - -| Model | mAP | -| ------ | ------- | -| Zipformer-AT | 45.1 | - -The training command is: - -```bash -export CUDA_VISIBLE_DEVICES="4,5,6,7" -subset=full - -python zipformer/train.py \ - --world-size 4 \ - --num-epochs 50 \ - --exp-dir zipformer/exp_at_as_${subset} \ - --start-epoch 1 \ - --use-fp16 1 \ - --num-events 527 \ - --audioset-subset $subset \ - --max-duration 1000 \ - --enable-musan True \ - --master-port 13455 -``` - -We recommend that you train the model with weighted sampler, as the model converges -faster with better performance: - -| Model | mAP | -| ------ | ------- | -| Zipformer-AT, train with weighted sampler | 46.6 | - -The evaluation command is: - -```bash -export CUDA_VISIBLE_DEVICES="4,5,6,7" -subset=full -weighted_sampler=1 -bucket_sampler=0 -lr_epochs=15 - -python zipformer/train.py \ - --world-size 4 \ - --audioset-subset $subset \ - --num-epochs 120 \ - --start-epoch 1 \ - --use-fp16 1 \ - --num-events 527 \ - --lr-epochs $lr_epochs \ - --exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \ - --weighted-sampler $weighted_sampler \ - --bucketing-sampler $bucket_sampler \ - --max-duration 1000 \ - --enable-musan True \ - --master-port 13452 -``` - -The command for evaluation is the same. 
The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler - - -#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M - -You can find a pretrained model, training logs, decoding logs, and decoding results at: - - -The model achieves the following mean averaged precision on AudioSet: - -| Model | mAP | -| ------ | ------- | -| Zipformer-S-AT | 45.1 | - -The training command is: - -```bash -export CUDA_VISIBLE_DEVICES="4,5,6,7" -subset=full - -python zipformer/train.py \ - --world-size 4 \ - --num-epochs 50 \ - --exp-dir zipformer/exp_small_at_as_${subset} \ - --start-epoch 1 \ - --use-fp16 1 \ - --num-events 527 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 \ - --audioset-subset $subset \ - --max-duration 1200 \ - --enable-musan True \ - --master-port 13455 -``` - -The evaluation command is: - -```bash -python zipformer/evaluate.py \ - --epoch 31 \ - --avg 4 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 \ - --exp-dir zipformer/exp_small_at_as_full \ - --max-duration 500 -``` diff --git a/egs/audioset/AT/local/compute_fbank_musan.py b/egs/audioset/AT/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/audioset/AT/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/audioset/AT/local/compute_weight.py b/egs/audioset/AT/local/compute_weight.py deleted file mode 100644 index a0deddc0c..000000000 --- a/egs/audioset/AT/local/compute_weight.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file generates the manifest and computes the fbank features for AudioSet -dataset. The generated manifests and features are stored in data/fbank. 
-""" - -import argparse - -import lhotse -from lhotse import load_manifest - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" - ) - - parser.add_argument( - "--output", - type=str, - required=True, - ) - return parser - - -def main(): - # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py - parser = get_parser() - args = parser.parse_args() - - cuts = load_manifest(args.input_manifest) - - print(f"A total of {len(cuts)} cuts.") - - label_count = [0] * 527 # a total of 527 classes - for c in cuts: - audio_event = c.supervisions[0].audio_event - labels = list(map(int, audio_event.split(";"))) - for label in labels: - label_count[label] += 1 - - with open(args.output, "w") as f: - for c in cuts: - audio_event = c.supervisions[0].audio_event - labels = list(map(int, audio_event.split(";"))) - weight = 0 - for label in labels: - weight += 1000 / (label_count[label] + 0.01) - f.write(f"{c.id} {weight}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py deleted file mode 100644 index 1c5b3457c..000000000 --- a/egs/audioset/AT/local/generate_audioset_manifest.py +++ /dev/null @@ -1,177 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file generates the manifest and computes the fbank features for AudioSet -dataset. The generated manifests and features are stored in data/fbank. 
-""" - -import argparse -import csv -import glob -import logging -import os -from typing import Dict - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.audio import Recording -from lhotse.cut import MonoCut -from lhotse.supervision import SupervisionSegment - -from icefall.utils import get_executor - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_ID_mapping(csv_file): - # get a mapping between class ID and class name - mapping = {} - with open(csv_file, "r") as fin: - reader = csv.reader(fin, delimiter=",") - for i, row in enumerate(reader): - if i == 0: - continue - mapping[row[1]] = row[0] - return mapping - - -def parse_csv(csv_file: str, id_mapping: Dict): - # The content of the csv file shoud be something like this - # ------------------------------------------------------ - # filename label - # dataset/AudioSet/balanced/xxxx.wav 0;451 - # dataset/AudioSet/balanced/xxxy.wav 375 - # ------------------------------------------------------ - - def name2id(names): - ids = [id_mapping[name] for name in names.split(",")] - return ";".join(ids) - - mapping = {} - with open(csv_file, "r") as fin: - reader = csv.reader(fin, delimiter=" ") - for i, row in enumerate(reader): - if i <= 2: - continue - key = row[0].replace(",", "") - mapping[key] = name2id(row[-1]) - return mapping - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument("--dataset-dir", type=str, default="downloads/audioset") - - parser.add_argument( - "--split", - type=str, - default="balanced", - choices=["balanced", "unbalanced", "eval"], - ) - - parser.add_argument( - "--feat-output-dir", - type=str, - default="data/fbank", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - dataset_dir = args.dataset_dir - split = args.split - feat_output_dir = args.feat_output_dir - - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - if split in ["balanced", "unbalanced"]: - csv_file = f"{dataset_dir}/{split}_train_segments.csv" - elif split == "eval": - csv_file = f"{dataset_dir}/eval_segments.csv" - else: - raise ValueError() - - class_indices_csv = f"{dataset_dir}/class_labels_indices.csv" - id_mapping = get_ID_mapping(class_indices_csv) - labels = parse_csv(csv_file, id_mapping) - - audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav") - - new_cuts = [] - for i, audio in enumerate(audio_files): - cut_id = audio.split("/")[-1].split("_")[0] - recording = Recording.from_file(audio, cut_id) - cut = MonoCut( - id=cut_id, - start=0.0, - duration=recording.duration, - channel=0, - recording=recording, - ) - supervision = SupervisionSegment( - id=cut_id, - recording_id=cut.recording.id, - start=0.0, - channel=0, - duration=cut.duration, - ) - try: - supervision.audio_event = labels[cut_id] - except KeyError: - logging.info(f"No labels found for {cut_id}.") - continue - cut.supervisions = [supervision] - new_cuts.append(cut) - - if i % 100 == 0 and i: - logging.info(f"Processed {i} cuts until now.") - - cuts = CutSet.from_cuts(new_cuts) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - logging.info(f"Computing fbank features for {split}") - with get_executor() as ex: - cuts = cuts.compute_and_store_features( - extractor=extractor, - storage_path=f"{feat_output_dir}/{split}_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - manifest_output_dir = feat_output_dir + 
"/" + f"cuts_audioset_{split}.jsonl.gz" - - logging.info(f"Storing the manifest to {manifest_output_dir}") - cuts.to_jsonl(manifest_output_dir) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh deleted file mode 100755 index 8beaf2d86..000000000 --- a/egs/audioset/AT/prepare.sh +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -# run step 0 to step 5 by default -stage=-1 -stop_stage=4 - -dl_dir=$PWD/download -fbank_dir=data/fbank - -# we assume that you have your downloaded the AudioSet and placed -# it under $dl_dir/audioset, the folder structure should look like -# this: -# - $dl_dir/audioset -# - balanced -# - eval -# - unbalanced -# If you haven't downloaded the AudioSet, please refer to -# https://github.com/RicherMans/SAT/blob/main/datasets/audioset/1_download_audioset.sh. - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "Running prepare.sh" - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage 0: Download the necessary csv files" - if [ ! -e $dl_dir/audioset/.csv.done]; then - wget --continue "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" -O "${dl_dir}/audioset/class_labels_indices.csv" - wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv -O "${dl_dir}/audioset/balanced_train_segments.csv" - wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv -O "${dl_dir}/audioset/eval_segments.csv" - touch $dl_dir/audioset/.csv.done - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" - if [! -e $fbank_dir/.balanced.done]; then - python local/generate_audioset_manifest.py \ - --dataset-dir $dl_dir/audioset \ - --split balanced \ - --feat-output-dir $fbank_dir - touch $fbank_dir/.balanced.done - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Construct the audioset manifest and compute the fbank features for unbalanced set" - fbank_dir=data/fbank - if [! -e $fbank_dir/.unbalanced.done]; then - python local/generate_audioset_manifest.py \ - --dataset-dir $dl_dir/audioset \ - --split unbalanced \ - --feat-output-dir $fbank_dir - touch $fbank_dir/.unbalanced.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Construct the audioset manifest and compute the fbank features for eval set" - fbank_dir=data/fbank - if [! 
-e $fbank_dir/.eval.done]; then - python local/generate_audioset_manifest.py \ - --dataset-dir $dl_dir/audioset \ - --split eval \ - --feat-output-dir $fbank_dir - touch $fbank_dir/.eval.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -# The following stages are required to do weighted-sampling training -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare for weighted-sampling training" - if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then - lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz - fi - python ./local/compute_weight.py \ - --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \ - --output $fbank_dir/sampling_weights_full.txt -fi diff --git a/egs/audioset/AT/shared b/egs/audioset/AT/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/audioset/AT/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py deleted file mode 100644 index b7df01539..000000000 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - AudioTaggingDataset, - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - WeightedSimpleCutSampler, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AudioSetATDatamodule: - """ - DataModule for k2 audio tagging (AT) experiments. 
- - - It contains all the common data pipeline modules used in AT - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in AT tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="AT data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--audioset-subset", - type=str, - default="balanced", - choices=["balanced", "full"], - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with audioset train/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--weighted-sampler", - type=str2bool, - default=False, - help="When enabled, samples are drawn from by their weights. " - "It cannot be used together with bucketing sampler", - ) - group.add_argument( - "--num-samples", - type=int, - default=200000, - help="The number of samples to be drawn in each epoch. Only be used" - "for weighed sampler", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ): - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = AudioTaggingDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. 
- # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = AudioTaggingDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - assert ( - not self.args.weighted_sampler - ), "weighted sampling is not supported in bucket sampler" - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - if self.args.weighted_sampler: - # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset" - logging.info("Using weighted SimpleCutSampler") - weights = self.audioset_sampling_weights() - train_sampler = WeightedSimpleCutSampler( - cuts_train, - weights, - num_samples=self.args.num_samples, - max_duration=self.args.max_duration, - shuffle=False, # do not support shuffle - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - drop_last=self.args.drop_last, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = AudioTaggingDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = AudioTaggingDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = AudioTaggingDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def audioset_train_cuts(self) -> CutSet: - logging.info("About to get the audioset training cuts.") - if not self.args.weighted_sampler: - balanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" - ) - if self.args.audioset_subset == "full": - unbalanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" - ) - cuts = CutSet.mux( - balanced_cuts, - unbalanced_cuts, - weights=[20000, 2000000], - stop_early=True, - ) - else: - cuts = balanced_cuts - else: - # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet" - cuts = load_manifest( - self.args.manifest_dir - / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz" - ) - logging.info(f"Get {len(cuts)} cuts in total.") - - return cuts - - @lru_cache() - def audioset_eval_cuts(self) -> CutSet: - logging.info("About to get audioset eval cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" - ) - - @lru_cache() - def audioset_sampling_weights(self): - logging.info( - f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" - ) - weights = [] - with open( - self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt", - "r", - ) as f: - while True: - line = f.readline() - if not line: - break - weight = float(line.split()[1]) - weights.append(weight) - logging.info(f"Get the sampling weight for {len(weights)} cuts") - return weights diff --git a/egs/audioset/AT/zipformer/encoder_interface.py b/egs/audioset/AT/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/audioset/AT/zipformer/encoder_interface.py +++ /dev/null @@ 
-1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py deleted file mode 100644 index 0a1b8ea5f..000000000 --- a/egs/audioset/AT/zipformer/evaluate.py +++ /dev/null @@ -1,327 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0" - -./zipformer/evaluate.py \ - --epoch 50 \ - --avg 10 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - - -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict - -import torch -import torch.nn as nn -from at_datamodule import AudioSetATDatamodule - -try: - from sklearn.metrics import average_precision_score -except: - raise ImportError(f"Please run\n" "pip3 install -U scikit-learn") -from train import add_model_arguments, get_model, get_params, str2multihot - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - add_model_arguments(parser) - - return parser - - -def inference_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -): - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3, feature.shape - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - audio_event = supervisions["audio_event"] - - label, _ = str2multihot(audio_event) - label = label.detach().cpu() - - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) - # convert to probabilities between 0-1 - audio_logits = audio_logits.sigmoid().detach().cpu() - - return audio_logits, label - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> Dict: - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - all_logits = [] - all_labels = [] - - for batch_idx, batch in enumerate(dl): - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - num_cuts += len(cut_ids) - - audio_logits, labels = inference_one_batch( - params=params, - model=model, - batch=batch, - ) - - all_logits.append(audio_logits) - all_labels.append(labels) - - if batch_idx % 20 == 1: - logging.info(f"Processed {num_cuts} cuts already.") - logging.info("Finish collecting audio logits") - - return all_logits, all_labels - - -@torch.no_grad() -def main(): - parser = get_parser() - AudioSetATDatamodule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "inference_audio_tagging" - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Evaluation started") - - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info("About to create model") - - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - args.return_cuts = True - audioset = AudioSetATDatamodule(args) - - audioset_cuts = audioset.audioset_eval_cuts() - - audioset_dl = audioset.valid_dataloaders(audioset_cuts) - - test_sets = ["audioset_eval"] - - logits, labels = decode_dataset( - dl=audioset_dl, - params=params, - model=model, - ) - - logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy() - labels = torch.cat(labels, dim=0).long().detach().numpy() - - # compute the metric - mAP = average_precision_score( - y_true=labels, - y_score=logits, - ) - - logging.info(f"mAP for audioset eval is: {mAP}") - - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py deleted file mode 100755 index 2b0ec8b4b..000000000 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ /dev/null @@ -1,411 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -Usage of this script: - - repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 - repo=$(basename $repo_url) - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo/exp - git lfs pull --include pretrained.pt - ln -s pretrained.pt epoch-99.pt - popd - - python3 zipformer/export-onnx.py \ - --exp-dir $repo/exp \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 - - pushd $repo/exp - mv model-epoch-99-avg-1.onnx model.onnx - mv model-epoch-99-avg-1.int8.onnx model.int8.onnx - popd - -See ./onnx_pretrained.py -use the exported ONNX models. 
-""" - -import argparse -import logging -from pathlib import Path -from typing import Dict - -import onnx -import onnxoptimizer -import torch -import torch.nn as nn -from onnxruntime.quantization import QuantType, quantize_dynamic -from onnxsim import simplify -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxAudioTagger(nn.Module): - """A wrapper for Zipformer audio tagger""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, classifier: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.classifier = classifier - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> torch.Tensor: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). 
Its dtype is torch.int64 - Returns: - Return a tensor containing: - - probs, A 2-D tensor of shape (N, num_classes) - - """ - x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C) - - logits = self.classifier(encoder_out) # (N, T, num_classes) - # Note that this is slightly different from model.py for better - # support of onnx - logits = logits.mean(dim=1) - probs = logits.sigmoid() - return probs - - -def export_audio_tagging_model_onnx( - model: OnnxAudioTagger, - filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - model: - The input encoder model - filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 200, 80, dtype=torch.float32) - x_lens = torch.tensor([200], dtype=torch.int64) - - model = torch.jit.trace(model, (x, x_lens)) - - torch.onnx.export( - model, - (x, x_lens), - filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["logits"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "probs": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "zipformer2 audio tagger", - "url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=filename, meta_data=meta_data) - - -def optimize_model(filename): - # see - # https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537 - # and - # https://github.com/onnx/onnx/issues/582#issuecomment-937788108 - # and - # https://github.com/onnx/optimizer/issues/110 - # and - # https://qiita.com/Yossy_Hal/items/34f3b2aef2199baf7f5f - passes = ["eliminate_unused_initializer"] - onnx_model = onnx.load(filename) - onnx_model = onnxoptimizer.optimize(onnx_model, passes) - - model_simp, check = simplify(onnx_model) - if check: - logging.info("Simplified the model!") - onnx_model = model_simp - else: - logging.info("Failed to simplify the model!") - - onnx.save(onnx_model, filename) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - model = OnnxAudioTagger( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - classifier=model.classifier, - ) - - model_num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"total parameters: {model_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting audio tagging model") - model_filename = params.exp_dir / f"model-{suffix}.onnx" - export_audio_tagging_model_onnx( - model, - model_filename, - opset_version=opset_version, - ) - optimize_model(model_filename) - logging.info(f"Exported audio tagging model to {model_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"model-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - optimize_model(model_filename_int8) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py deleted file mode 100755 index 6ceeca8de..000000000 --- a/egs/audioset/AT/zipformer/export.py +++ /dev/null @@ -1,340 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# 
Zengwei Yao, -# Wei Kang, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is an example for AudioSet dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -and https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --epoch 30 \ - --avg 9 - - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `zipformer/evaluate.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/audioset/AT - ./zipformer/evaluate.py \ - --exp-dir ./zipformer/exp \ - --use-averaged-model False \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 - -Check ./pretrained.py for its usage. - -""" - -import argparse -import logging -from pathlib import Path -from typing import Tuple - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
- "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class Classifier(nn.Module): - """A wrapper for audio tagging classifier""" - - def __init__(self, classifier: nn.Module) -> None: - super().__init__() - self.classifier = classifier - - def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor): - """ - Args: - encoder_out: - A 3-D tensor of shape (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. 
- """ - logits = self.classifier(encoder_out) # (N, T, num_classes) - padding_mask = make_pad_mask(encoder_out_lens) - logits[padding_mask] = 0 - logits = logits.sum(dim=1) # mask the padding frames - logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( - logits - ) # normalize the logits - - return logits - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - model.classifier = Classifier(model.classifier) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py deleted file mode 100755 index d376aa148..000000000 --- a/egs/audioset/AT/zipformer/jit_pretrained.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# 2024 Xiaoyu Yang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - - repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 - repo=$(basename $repo_url) - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo/exp - git lfs pull --include jit_script.pt - popd - - python3 zipformer/jit_pretrained.py \ - --nn-model-filename $repo/exp/jit_script.pt \ - --label-dict $repo/data/class_labels_indices.csv \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav \ - $repo/test_wavs/3.wav \ - $repo/test_wavs/4.wav -""" - -import argparse -import csv -import logging -import math -from typing import List - -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--label-dict", - type=str, - help="""class_labels_indices.csv.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - # get the label dictionary - label_dict = {} - with open(args.label_dict, "r") as f: - reader = csv.reader(f, delimiter=",") - for i, row in enumerate(reader): - if i == 0: - continue - label_dict[int(row[0])] = row[2] - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - features=features, - feature_lengths=feature_lengths, - ) - - logits = model.classifier(encoder_out, encoder_out_lens) - - for filename, logit in zip(args.sound_files, logits): - topk_prob, topk_index = logit.sigmoid().topk(5) - topk_labels = [label_dict[index.item()] for index in topk_index] - logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with " - f"probability of {topk_prob.tolist()}" - ) - - logging.info("Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py deleted file mode 100644 index fb8e2dd85..000000000 --- a/egs/audioset/AT/zipformer/model.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang, -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import make_pad_mask - - -class AudioTaggingModel(nn.Module): - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - encoder_dim: int = 384, - num_events: int = 527, - ): - """An audio tagging model - - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the audio encoder, e.g., a Zipformer. It accepts - two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `encoder_out` of shape (N, T, encoder_dim) and - `encoder_out_lens` of shape (N,). - encoder_dim: - Dimension of the encoder. - num_events: - The number of classes. - """ - super().__init__() - - assert isinstance(encoder, EncoderInterface), type(encoder) - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.encoder_dim = encoder_dim - - self.classifier = nn.Sequential( - nn.Dropout(0.1), - nn.Linear(encoder_dim, num_events), - ) - - # for multi-label classification - self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum") - - def forward_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute encoder outputs. - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - - Returns: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - """ - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - x, x_lens = self.encoder_embed(x, x_lens) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - - return encoder_out, encoder_out_lens - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - target: - The ground-truth labels of the audio events; can be multi-hot. - Returns: - The binary cross-entropy loss. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - - # Compute encoder outputs - encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) - - # Forward the audio tagging head - logits = self.forward_audio_tagging( - encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) # (N, num_classes) - - loss = self.criterion(logits, target) - - return loss - - def forward_audio_tagging(self, encoder_out, encoder_out_lens): - """ - Args: - encoder_out: - A 3-D tensor of shape (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - - Returns: - A 2-D tensor of shape (N, num_classes). 
- """ - logits = self.classifier(encoder_out) # (N, T, num_classes) - padding_mask = make_pad_mask(encoder_out_lens) - logits[padding_mask] = 0 - logits = logits.sum(dim=1) # mask the padding frames - logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( - logits - ) # normalize the logits - - return logits diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py deleted file mode 100755 index 8de60bbb5..000000000 --- a/egs/audioset/AT/zipformer/onnx_pretrained.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. - -Usage of this script: - - repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 - repo=$(basename $repo_url) - git clone $repo_url - pushd $repo - git lfs pull --include "*.onnx" - popd - - for m in model.onnx model.int8.onnx; do - python3 zipformer/onnx_pretrained.py \ - --model-filename $repo/model.onnx \ - --label-dict $repo/class_labels_indices.csv \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav \ - $repo/test_wavs/3.wav \ - $repo/test_wavs/4.wav - done -""" - -import argparse -import csv -import logging -import math -from typing import List - -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model. ", - ) - - parser.add_argument( - "--label-dict", - type=str, - help="""class_labels_indices.csv.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - nn_model: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_model(nn_model) - - def init_model(self, nn_model: str): - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - meta = self.model.get_modelmeta().custom_metadata_map - print(meta) - - def __call__( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). 
Its dtype is torch.int64 - Returns: - Return a Tensor: - - probs, its shape is (N, num_classes) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - nn_model=args.model_filename, - ) - - # get the label dictionary - label_dict = {} - with open(args.label_dict, "r") as f: - reader = csv.reader(f, delimiter=",") - for i, row in enumerate(reader): - if i == 0: - continue - label_dict[int(row[0])] = row[2] - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - probs = model(features, feature_lengths) - - for filename, prob in zip(args.sound_files, probs): - topk_prob, topk_index = prob.topk(5) - topk_labels = [label_dict[index.item()] for index in topk_index] - logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with " - f"probability of {topk_prob.tolist()}" - ) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/audioset/AT/zipformer/optim.py b/egs/audioset/AT/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/audioset/AT/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py deleted file mode 100755 index bdbd799fa..000000000 --- a/egs/audioset/AT/zipformer/pretrained.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
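For reference, the ONNX inference performed by `onnx_pretrained.py` above reduces to a single `session.run` call with two inputs (padded fbank features and their lengths) and one `(N, num_classes)` output. A condensed sketch of that call pattern, assuming a model exported with the same input/output layout (`model.onnx` is a placeholder path):

```python
import numpy as np
import onnxruntime as ort


def run_audio_tagging(model_path: str, feats: np.ndarray, feat_lens: np.ndarray) -> np.ndarray:
    """feats: (N, T, C) float32 fbank features; feat_lens: (N,) int64 valid frame counts."""
    opts = ort.SessionOptions()
    opts.inter_op_num_threads = 1
    opts.intra_op_num_threads = 4
    sess = ort.InferenceSession(
        model_path, sess_options=opts, providers=["CPUExecutionProvider"]
    )
    (probs,) = sess.run(
        [sess.get_outputs()[0].name],
        {
            sess.get_inputs()[0].name: feats,
            sess.get_inputs()[1].name: feat_lens,
        },
    )
    return probs  # (N, num_classes)
```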
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is an example for the AudioSet dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -Usage of this script: - - repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 - repo=$(basename $repo_url) - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo/exp - git lfs pull --include pretrained.pt - popd - - python3 zipformer/pretrained.py \ - --checkpoint $repo/exp/pretrained.pt \ - --label-dict $repo/data/class_labels_indices.csv \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav \ - $repo/test_wavs/3.wav \ - $repo/test_wavs/4.wav -""" - - -import argparse -import csv -import logging -import math -from typing import List - -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--label-dict", - type=str, - help="""class_labels_indices.csv.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - # get the label dictionary - label_dict = {} - with open(params.label_dict, "r") as f: - reader = csv.reader(f, delimiter=",") - for i, row in enumerate(reader): - if i == 0: - continue - label_dict[int(row[0])] = row[2] - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward and predict the audio events - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) - - for filename, logit in zip(args.sound_files, logits): - topk_prob, topk_index = logit.sigmoid().topk(5) - topk_labels = [label_dict[index.item()] for index in topk_index] - logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with " - f"probability of {topk_prob.tolist()}" - ) - - logging.info("Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/audioset/AT/zipformer/scaling.py b/egs/audioset/AT/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/audioset/AT/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/scaling_converter.py b/egs/audioset/AT/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/audioset/AT/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/subsampling.py b/egs/audioset/AT/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/audioset/AT/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py deleted file mode 100644 index 67c703364..000000000 
--- a/egs/audioset/AT/zipformer/train.py +++ /dev/null @@ -1,1194 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --audioset-subset full \ - --max-duration 1000 - - -""" - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from at_datamodule import AudioSetATDatamodule -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AudioTaggingModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
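As a worked illustration of the comment above: with the assumed settings `--max-duration 1000`, `--world-size 4` and `--ref-duration 600` (numbers chosen only for this example), every real batch advances the internal schedules as if roughly 6.7 reference batches had been seen.

```python
def get_adjusted_batch_count(
    batch_idx_train: int, max_duration: float, world_size: int, ref_duration: float
) -> float:
    # Rescale the raw batch index by how much audio each optimizer step actually
    # consumes, relative to the reference duration the schedules were tuned for.
    return batch_idx_train * (max_duration * world_size) / ref_duration


# 3000 real batches at 1000 s per GPU across 4 GPUs count as 20000 reference batches.
print(get_adjusted_batch_count(3000, 1000, 4, 600))  # 20000.0
```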
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model. Do not recommend to use this for AT", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--num-events", type=int, default=527, help="Number of sound events" - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. 
`model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def _str2modulelist(s: str, add_dot: bool = True): - if add_dot: - return [ss.strip() + "." for ss in s.split(",")] if s is not None else None - else: - return [ss.strip() for ss in s.split(",")] if s is not None else None - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
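The shape arithmetic in the comment above is easy to check numerically. The helpers below only evaluate the stated formula, T -> (T - 7) // 2 followed by one more halving at the encoder output; they are not the actual `Conv2dSubsampling` implementation.

```python
def embed_out_len(T: int) -> int:
    # Conv2dSubsampling output length, as stated in the comment above.
    return (T - 7) // 2


def encoder_out_len(T: int) -> int:
    # The encoder downsamples once more by a factor of 2 at its output.
    return embed_out_len(T) // 2


for T in (100, 1000, 3000):
    print(T, "->", embed_out_len(T), "->", encoder_out_len(T))
# 100 -> 46 -> 23, 1000 -> 496 -> 248, 3000 -> 1496 -> 748
```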
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_model(params: AttributeDict) -> nn.Module: - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - model = AudioTaggingModel( - encoder_embed=encoder_embed, - encoder=encoder, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - num_events=params.num_events, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
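The resume rule implemented above reduces to a small precedence check: `--start-batch` wins over `--start-epoch`, and if neither is set past its default the function returns `None` and training starts from scratch. A minimal restatement of just that selection, with illustrative paths:

```python
from pathlib import Path
from typing import Optional


def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    # --start-batch takes precedence over --start-epoch.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run: nothing to load


print(resume_filename(Path("zipformer/exp"), 0, 1))     # None
print(resume_filename(Path("zipformer/exp"), 0, 5))     # zipformer/exp/epoch-4.pt
print(resume_filename(Path("zipformer/exp"), 8000, 5))  # zipformer/exp/checkpoint-8000.pt
```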
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.AudioTaggingDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - events = supervisions[ - "audio_event" - ] # the label indices are in CED format (https://github.com/RicherMans/CED) - labels, _ = str2multihot(events, n_classes=params.num_events) - labels = labels.to(device) - - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - with torch.set_grad_enabled(is_training): - loss = model( - x=feature, - x_lens=feature_lens, - target=labels, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
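To make the note above concrete: the targets are multi-hot vectors built by `str2multihot` (defined just below) from semicolon-separated CED-style indices, and `BCEWithLogitsLoss(reduction="sum")` adds up the per-class binary cross-entropies over the whole batch instead of averaging them, which is why the loss is later normalised by a frame count. A toy 5-class sketch of that label handling and loss:

```python
import torch


def str2multihot(events, n_classes=5):
    # "0;1" -> [1, 1, 0, 0, 0]; a simplified version of the helper defined below.
    labels = [list(map(int, e.split(";"))) for e in events]
    out = torch.zeros(len(labels), n_classes)
    for i, label in enumerate(labels):
        out[i, label] = 1
    return out


targets = str2multihot(["0;1", "1;2"], n_classes=5)
logits = torch.randn(2, 5)
criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
loss = criterion(logits, targets)  # summed over all classes and all utterances
print(targets)
print(loss.item())
```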
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def str2multihot(events: List[str], n_classes=527, id_mapping=None): - # Convert strings separated by semi-colon to multi-hot class labels - # input: ["0;1", "1;2"] - # output: torch.tensor([[1,1,0], [0,1,1]]) - labels = [list(map(int, event.split(";"))) for event in events] - batch_size = len(labels) - out = torch.zeros(batch_size, n_classes) - - for i, label in enumerate(labels): - if id_mapping is not None: - label = [id_mapping[lb] for lb in label] - out[i, label] = 1 - - return out, labels - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - num_samples = 0 - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = batch["inputs"].size(0) - num_samples += batch_size - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
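The block that follows works around that limitation by reading the scale out of the `GradScaler` and doubling it manually when it is too small. A self-contained sketch of the same pattern is below; note that `scaler._scale` is a private PyTorch attribute (this mirrors the recipe, but it is not a stable public API), and the 8.0 / 32.0 / 1e-5 thresholds are copied from the code that follows.

```python
from torch.cuda.amp import GradScaler


def maybe_boost_grad_scale(scaler: GradScaler, batch_idx: int) -> float:
    cur = scaler._scale.item()  # private attribute, mirroring the recipe
    # Grow the scale aggressively while it is small; only every 400 batches otherwise.
    if cur < 8.0 or (cur < 32.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)  # GradScaler.update() accepts an explicit new scale
    if cur < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur}")
    return cur
```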
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if num_samples > params.num_samples: - logging.info( - f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch" - ) - break - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs( - model, - lr=params.base_lr, - include_names=True, - ), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - audioset = AudioSetATDatamodule(args) - train_cuts = audioset.audioset_train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 30.0: - return False - - return True - - if not params.weighted_sampler: - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = audioset.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = audioset.audioset_eval_cuts() - valid_dl = audioset.valid_dataloaders(valid_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.AudioTaggingDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch( - batch, - params=params, - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AudioSetATDatamodule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/audioset/AT/zipformer/zipformer.py b/egs/audioset/AT/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/audioset/AT/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/.gitignore b/egs/baker_zh/TTS/.gitignore deleted file mode 100644 index 6441cd500..000000000 --- a/egs/baker_zh/TTS/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -path.sh -*.onnx -*.wav -generator_v1 -generator_v2 -generator_v3 diff --git a/egs/baker_zh/TTS/README.md b/egs/baker_zh/TTS/README.md deleted file mode 100644 index 7120c6f79..000000000 --- a/egs/baker_zh/TTS/README.md +++ /dev/null @@ -1,146 +0,0 @@ -# Introduction - -It is for the dataset from -https://en.data-baker.com/datasets/freeDatasets/ - -The dataset contains 10000 Chinese sentences of a native Chinese female speaker, -which is about 12 hours. - - -**Note**: The dataset is for non-commercial use only. - - -# matcha - -[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS) - -Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27). -The pull-request for this recipe can be found at - -The training command is given below: -```bash -python3 ./matcha/train.py \ - --exp-dir ./matcha/exp-1/ \ - --num-workers 4 \ - --world-size 1 \ - --num-epochs 2000 \ - --max-duration 1200 \ - --bucketing-sampler 1 \ - --start-epoch 1 -``` - -To inference, use: - -```bash -# Download Hifigan vocoder. We use Hifigan v2 below. 
You can select from v1, v2, or v3 - -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 - -python3 ./matcha/infer.py \ - --epoch 2000 \ - --exp-dir ./matcha/exp-1 \ - --vocoder ./generator_v2 \ - --tokens ./data/tokens.txt \ - --cmvn ./data/fbank/cmvn.json \ - --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ - --output-wav ./generated.wav -``` - -```bash -soxi ./generated.wav -``` - -prints: -``` -Input File : './generated.wav' -Channels : 1 -Sample Rate : 22050 -Precision : 16-bit -Duration : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors -File Size : 763k -Bit Rate : 353k -Sample Encoding: 16-bit Signed Integer PCM -``` - -https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024 - - -To export the checkpoint to onnx: -```bash -python3 ./matcha/export_onnx.py \ - --exp-dir ./matcha/exp-1 \ - --epoch 2000 \ - --tokens ./data/tokens.txt \ - --cmvn ./data/fbank/cmvn.json -``` - -The above command generates the following files: -``` --rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx --rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx --rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx --rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx --rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx -``` - -where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. - -**HINT**: If you get the following error while running `export_onnx.py`: - -``` -torch.onnx.errors.UnsupportedOperatorError: Exporting the operator -'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported. -``` - -please use `torch>=2.2.0`. - -To export the Hifigan vocoder to onnx, please use: - -```bash -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 - -python3 ./matcha/export_onnx_hifigan.py -``` - -The above command generates 3 files: - - - hifigan_v1.onnx - - hifigan_v2.onnx - - hifigan_v3.onnx - -**HINT**: You can download pre-exported hifigan ONNX models from - - -To use the generated onnx files to generate speech from text, please run: - -```bash - -# First, generate ./lexicon.txt -python3 ./matcha/generate_lexicon.py - -python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-4.onnx \ - --vocoder ./hifigan_v2.onnx \ - --tokens ./data/tokens.txt \ - --lexicon ./lexicon.txt \ - --input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \ - --output-wav ./1.wav -``` - -```bash -soxi ./1.wav - -Input File : './1.wav' -Channels : 1 -Sample Rate : 22050 -Precision : 16-bit -Duration : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors -File Size : 722k -Bit Rate : 353k -Sample Encoding: 16-bit Signed Integer PCM -``` - -https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e - diff --git a/egs/baker_zh/TTS/local/audio.py b/egs/baker_zh/TTS/local/audio.py deleted file mode 120000 index b70d91c92..000000000 --- a/egs/baker_zh/TTS/local/audio.py +++ /dev/null @@ -1 +0,0 @@ -../matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py deleted file mode 100755 index 0720158f2..000000000 --- 
a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the baker-zh dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from fbank import MatchaFbank, MatchaFbankConfig -from lhotse import CutSet, LilcomChunkyWriter, load_manifest -from lhotse.audio import RecordingSet -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-jobs", - type=int, - default=4, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - return parser - - -def compute_fbank_baker_zh(num_jobs: int): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - if num_jobs < 1: - num_jobs = os.cpu_count() - - logging.info(f"num_jobs: {num_jobs}") - logging.info(f"src_dir: {src_dir}") - logging.info(f"output_dir: {output_dir}") - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=22050, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - - prefix = "baker_zh" - suffix = "jsonl.gz" - - extractor = MatchaFbank(config) - - with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts.{suffix}" - logging.info(f"Processing {cuts_filename}") - cut_set = load_manifest(src_dir / cuts_filename).resample(22050) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - # Torch's multithreaded behavior needs to be disabled or - # it wastes a lot of CPU and slow things down. - # Do this outside of main() in case it needs to take effect - # even when we are not invoking the main (e.g. when spawning subprocesses). 
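Condensing the script above into its essential flow: pin Torch to a single thread (for the reason the comment explains), load and resample the manifest, then compute and store the features. `MatchaFbank` and `MatchaFbankConfig` come from the recipe-local `fbank.py` (a symlink into the ljspeech matcha recipe), so treat them as recipe-specific rather than as lhotse classes; the executor-based parallelism of the original script is omitted in this sketch.

```python
import torch
from fbank import MatchaFbank, MatchaFbankConfig  # recipe-local module
from lhotse import LilcomChunkyWriter, load_manifest

# Disable Torch's multithreading before any feature work, as discussed above.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,
)

cut_set = load_manifest("data/manifests/baker_zh_cuts.jsonl.gz").resample(22050)
cut_set = cut_set.compute_and_store_features(
    extractor=MatchaFbank(config),
    storage_path="data/fbank/baker_zh_feats",
    num_jobs=4,
    storage_type=LilcomChunkyWriter,
)
cut_set.to_file("data/fbank/baker_zh_cuts.jsonl.gz")
```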
- torch.set_num_threads(1) - torch.set_num_interop_threads(1) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_parser().parse_args() - compute_fbank_baker_zh(args.num_jobs) diff --git a/egs/baker_zh/TTS/local/compute_fbank_statistics.py b/egs/baker_zh/TTS/local/compute_fbank_statistics.py deleted file mode 120000 index fd1d8b52e..000000000 --- a/egs/baker_zh/TTS/local/compute_fbank_statistics.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/compute_fbank_statistics.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py deleted file mode 100755 index bf59cb466..000000000 --- a/egs/baker_zh/TTS/local/convert_text_to_tokens.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import re -from typing import List - -import jieba -from lhotse import load_manifest -from pypinyin import Style, lazy_pinyin, load_phrases_dict - -load_phrases_dict( - { - "行长": [["hang2"], ["zhang3"]], - "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], - } -) - -whiter_space_re = re.compile(r"\s+") - -punctuations_re = [ - (re.compile(x[0], re.IGNORECASE), x[1]) - for x in [ - (",", ","), - ("。", "."), - ("!", "!"), - ("?", "?"), - ("“", '"'), - ("”", '"'), - ("‘", "'"), - ("’", "'"), - (":", ":"), - ("、", ","), - ("B", "逼"), - ("P", "批"), - ] -] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--in-file", - type=str, - required=True, - help="Input cutset.", - ) - - parser.add_argument( - "--out-file", - type=str, - required=True, - help="Output cutset.", - ) - - return parser - - -def normalize_white_spaces(text): - return whiter_space_re.sub(" ", text) - - -def normalize_punctuations(text): - for regex, replacement in punctuations_re: - text = re.sub(regex, replacement, text) - return text - - -def split_text(text: str) -> List[str]: - """ - Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?' 
- Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?'] - """ - text = text.lower() - text = normalize_white_spaces(text) - text = normalize_punctuations(text) - ans = [] - - for seg in jieba.cut(text): - if seg in ",.!?:\"'": - ans.append(seg) - elif seg == " " and len(ans) > 0: - if ord("a") <= ord(ans[-1][-1]) <= ord("z"): - ans[-1] += seg - elif ord("a") <= ord(seg[0]) <= ord("z"): - if len(ans) == 0: - ans.append(seg) - continue - - if ans[-1][-1] == " ": - ans[-1] += seg - continue - - ans.append(seg) - else: - ans.append(seg) - - ans = [s.strip() for s in ans] - return ans - - -def main(): - args = get_parser().parse_args() - cuts = load_manifest(args.in_file) - for c in cuts: - assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions) - text = c.supervisions[0].normalized_text - - text_list = split_text(text) - tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True) - - c.tokens = tokens - - cuts.to_file(args.out_file) - - print(f"saved to {args.out_file}") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/local/fbank.py b/egs/baker_zh/TTS/local/fbank.py deleted file mode 120000 index 5bcf1fde5..000000000 --- a/egs/baker_zh/TTS/local/fbank.py +++ /dev/null @@ -1 +0,0 @@ -../matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py deleted file mode 100755 index b2abe1a71..000000000 --- a/egs/baker_zh/TTS/local/generate_tokens.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file generates the file tokens.txt. - -Usage: - -python3 ./local/generate_tokens.py > data/tokens.txt -""" - - -import argparse -from typing import List - -import jieba -from pypinyin import Style, lazy_pinyin, pinyin_dict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--tokens", - type=str, - required=True, - help="Path to to save tokens.txt.", - ) - - return parser - - -def generate_token_list() -> List[str]: - token_set = set() - - word_dict = pinyin_dict.pinyin_dict - i = 0 - for key in word_dict: - if not (0x4E00 <= key <= 0x9FFF): - continue - - w = chr(key) - t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] - token_set.add(t) - - no_digit = set() - for t in token_set: - if t[-1] not in "1234": - no_digit.add(t) - else: - no_digit.add(t[:-1]) - - no_digit.add("dei") - no_digit.add("tou") - no_digit.add("dia") - - for t in no_digit: - token_set.add(t) - for i in range(1, 5): - token_set.add(f"{t}{i}") - - ans = list(token_set) - ans.sort() - - punctuations = list(",.!?:\"'") - ans = punctuations + ans - - # use ID 0 for blank - # Use ID 1 of _ for padding - ans.insert(0, " ") - ans.insert(1, "_") # - - return ans - - -def main(): - args = get_parser().parse_args() - token_list = generate_token_list() - with open(args.tokens, "w", encoding="utf-8") as f: - for indx, token in enumerate(token_list): - f.write(f"{token} {indx}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py deleted file mode 100755 index 4e31028f7..000000000 --- a/egs/baker_zh/TTS/local/validate_manifest.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/spectrogram/baker_zh_cuts_all.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset.speech_synthesis import validate_for_tts - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet), type(cut_set) - - validate_for_tts(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/baker_zh/TTS/matcha/__init__.py b/egs/baker_zh/TTS/matcha/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/baker_zh/TTS/matcha/audio.py b/egs/baker_zh/TTS/matcha/audio.py deleted file mode 120000 index 62d3959d6..000000000 --- a/egs/baker_zh/TTS/matcha/audio.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py deleted file mode 100755 index 28efbfe61..000000000 --- a/egs/baker_zh/TTS/matcha/export_onnx.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -""" -This script exports a Matcha-TTS model to ONNX. -Note that the model outputs fbank. You need to use a vocoder to convert -it to audio. See also ./export_onnx_hifigan.py - -python3 ./matcha/export_onnx.py \ - --exp-dir ./matcha/exp-1 \ - --epoch 2000 \ - --tokens ./data/tokens.txt \ - --cmvn ./data/fbank/cmvn.json - -""" - -import argparse -import json -import logging -from pathlib import Path -from typing import Any, Dict - -import onnx -import torch -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=2000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp-new-3", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, Any]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - - while len(model.metadata_props): - model.metadata_props.pop() - - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class ModelWrapper(torch.nn.Module): - def __init__(self, model, num_steps: int = 5): - super().__init__() - self.model = model - self.num_steps = num_steps - - def forward( - self, - x: torch.Tensor, - x_lengths: torch.Tensor, - noise_scale: torch.Tensor, - length_scale: torch.Tensor, - ) -> torch.Tensor: - """ - Args: : - x: (batch_size, num_tokens), torch.int64 - x_lengths: (batch_size,), torch.int64 - noise_scale: (1,), torch.float32 - length_scale (1,), torch.float32 - Returns: - audio: (batch_size, num_samples) - - """ - mel = self.model.synthesise( - x=x, - x_lengths=x_lengths, - n_timesteps=self.num_steps, - temperature=noise_scale, - length_scale=length_scale, - )["mel"] - # mel: (batch_size, feat_dim, num_frames) - - return mel - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.pad_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - for num_steps in [2, 3, 4, 5, 6]: - logging.info(f"num_steps: {num_steps}") - wrapper = ModelWrapper(model, num_steps=num_steps) - wrapper.eval() - - # Use a large value so the rotary position embedding in the text - # encoder has a large initial length - x = torch.ones(1, 1000, dtype=torch.int64) - x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1.0]) - length_scale = torch.tensor([1.0]) - - opset_version = 14 - filename = f"model-steps-{num_steps}.onnx" - torch.onnx.export( - wrapper, - (x, x_lengths, noise_scale, length_scale), - filename, - opset_version=opset_version, - input_names=["x", "x_length", "noise_scale", "length_scale"], - output_names=["mel"], - dynamic_axes={ - "x": {0: "N", 1: "L"}, - "x_length": {0: "N"}, - "mel": {0: "N", 2: "L"}, - }, - ) - - meta_data = { - "model_type": "matcha-tts", - "language": "Chinese", - "has_espeak": 0, - "n_speakers": 1, - "jieba": 1, - "sample_rate": 22050, - "version": 1, - "pad_id": params.pad_id, - "model_author": "icefall", - "maintainer": "k2-fsa", - "dataset": "baker-zh", - "use_eos_bos": 0, - "dataset_url": "https://www.data-baker.com/open_source.html", - 
"dataset_comment": "The dataset is for non-commercial use only.", - "num_ode_steps": num_steps, - } - add_meta_data(filename=filename, meta_data=meta_data) - print(meta_data) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py deleted file mode 120000 index d0b8af15b..000000000 --- a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/export_onnx_hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/fbank.py b/egs/baker_zh/TTS/matcha/fbank.py deleted file mode 120000 index 3cfb7fe3f..000000000 --- a/egs/baker_zh/TTS/matcha/fbank.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/generate_lexicon.py b/egs/baker_zh/TTS/matcha/generate_lexicon.py deleted file mode 100755 index f26f28e91..000000000 --- a/egs/baker_zh/TTS/matcha/generate_lexicon.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 - -import jieba -from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict -from tokenizer import Tokenizer - -load_phrases_dict( - { - "行长": [["hang2"], ["zhang3"]], - "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], - } -) - - -def main(): - filename = "lexicon.txt" - tokens = "./data/tokens.txt" - tokenizer = Tokenizer(tokens) - - word_dict = pinyin_dict.pinyin_dict - phrases = phrases_dict.phrases_dict - - i = 0 - with open(filename, "w", encoding="utf-8") as f: - for key in word_dict: - if not (0x4E00 <= key <= 0x9FFF): - continue - - w = chr(key) - tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] - - f.write(f"{w} {tokens}\n") - - for key in phrases: - tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True) - tokens = " ".join(tokens) - - f.write(f"{key} {tokens}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/matcha/hifigan b/egs/baker_zh/TTS/matcha/hifigan deleted file mode 120000 index c0a91072c..000000000 --- a/egs/baker_zh/TTS/matcha/hifigan +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/hifigan \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py deleted file mode 100755 index b90c2fdbd..000000000 --- a/egs/baker_zh/TTS/matcha/infer.py +++ /dev/null @@ -1,342 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) -""" -python3 ./matcha/infer.py \ - --epoch 2000 \ - --exp-dir ./matcha/exp-1 \ - --vocoder ./generator_v2 \ - --tokens ./data/tokens.txt \ - --cmvn ./data/fbank/cmvn.json \ - --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ - --output-wav ./generated.wav -""" - -import argparse -import datetime as dt -import json -import logging -from pathlib import Path - -import soundfile as sf -import torch -import torch.nn as nn -from hifigan.config import v1, v2, v3 -from hifigan.denoiser import Denoiser -from hifigan.models import Generator as HiFiGAN -from local.convert_text_to_tokens import split_text -from pypinyin import Style, lazy_pinyin -from tokenizer import Tokenizer -from train import get_model, get_params -from tts_datamodule import BakerZhTtsDataModule - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--vocoder", - type=Path, - default="./generator_v1", - help="Path to the vocoder", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - # The following arguments are used for inference on single text - parser.add_argument( - "--input-text", - type=str, - required=False, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=False, - help="The filename of the wave to save the generated speech", - ) - - parser.add_argument( - "--sampling-rate", - type=int, - default=22050, - help="The sampling rate of the generated speech (default: 22050 for baker_zh)", - ) - - return parser - - -def load_vocoder(checkpoint_path: Path) -> nn.Module: - checkpoint_path = str(checkpoint_path) - if checkpoint_path.endswith("v1"): - h = AttributeDict(v1) - elif checkpoint_path.endswith("v2"): - h = AttributeDict(v2) - elif checkpoint_path.endswith("v3"): - h = AttributeDict(v3) - else: - raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") - - hifigan = HiFiGAN(h).to("cpu") - hifigan.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["generator"] - ) - _ = hifigan.eval() - hifigan.remove_weight_norm() - return hifigan - - -def to_waveform( - mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module -) -> torch.Tensor: - audio = vocoder(mel).clamp(-1, 1) - audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() - return audio.squeeze() - - -def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: - text = split_text(text) - tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True) - - x = tokenizer.texts_to_token_ids([tokens]) - x = torch.tensor(x, dtype=torch.long, device=device) - x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) - return {"x_orig": text, "x": x, "x_lengths": x_lengths} - - -def synthesize( - model: nn.Module, - tokenizer: 
Tokenizer, - n_timesteps: int, - text: str, - length_scale: float, - temperature: float, - device: str = "cpu", - spks=None, -) -> dict: - text_processed = process_text(text=text, tokenizer=tokenizer, device=device) - start_t = dt.datetime.now() - output = model.synthesise( - text_processed["x"], - text_processed["x_lengths"], - n_timesteps=n_timesteps, - temperature=temperature, - spks=spks, - length_scale=length_scale, - ) - # merge everything to one dict - output.update({"start_t": start_t, **text_processed}) - return output - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - vocoder: nn.Module, - denoiser: nn.Module, - tokenizer: Tokenizer, -) -> None: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - tokenizer: - Used to convert text to phonemes. - """ - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["tokens"]) - - texts = [c.supervisions[0].normalized_text for c in batch["cut"]] - - audio = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - - for i in range(batch_size): - output = synthesize( - model=model, - tokenizer=tokenizer, - n_timesteps=params.n_timesteps, - text=texts[i], - length_scale=params.length_scale, - temperature=params.temperature, - device=device, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write( - file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", - data=output["waveform"], - samplerate=params.data_args.sampling_rate, - subtype="PCM_16", - ) - sf.write( - file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", - data=audio[i].numpy(), - samplerate=params.data_args.sampling_rate, - subtype="PCM_16", - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - -@torch.inference_mode() -def main(): - parser = get_parser() - BakerZhTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - - # Number of ODE Solver steps - params.n_timesteps = 2 - - # Changes to the speaking rate - params.length_scale = 1.0 - - # Sampling temperature - 
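- # (It scales the random noise fed to the flow-matching decoder: larger values
- # give more varied but less stable prosody.)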
params.temperature = 0.667 - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - - # we need cut ids to organize tts results. - args.return_cuts = True - baker_zh = BakerZhTtsDataModule(args) - - test_cuts = baker_zh.test_cuts() - test_dl = baker_zh.test_dataloaders(test_cuts) - - if not Path(params.vocoder).is_file(): - raise ValueError(f"{params.vocoder} does not exist") - - vocoder = load_vocoder(params.vocoder) - vocoder.to(device) - - denoiser = Denoiser(vocoder, mode="zeros") - denoiser.to(device) - - if params.input_text is not None and params.output_wav is not None: - logging.info("Synthesizing a single text") - output = synthesize( - model=model, - tokenizer=tokenizer, - n_timesteps=params.n_timesteps, - text=params.input_text, - length_scale=params.length_scale, - temperature=params.temperature, - device=device, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write( - file=params.output_wav, - data=output["waveform"], - samplerate=params.sampling_rate, - subtype="PCM_16", - ) - else: - logging.info("Decoding the test set") - infer_dataset( - dl=test_dl, - params=params, - model=model, - vocoder=vocoder, - denoiser=denoiser, - tokenizer=tokenizer, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/matcha/model.py b/egs/baker_zh/TTS/matcha/model.py deleted file mode 120000 index 8a1b812a9..000000000 --- a/egs/baker_zh/TTS/matcha/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/model.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/models b/egs/baker_zh/TTS/matcha/models deleted file mode 120000 index 09a862665..000000000 --- a/egs/baker_zh/TTS/matcha/models +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/models \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/monotonic_align b/egs/baker_zh/TTS/matcha/monotonic_align deleted file mode 120000 index d0a0dd6b5..000000000 --- a/egs/baker_zh/TTS/matcha/monotonic_align +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/onnx_pretrained.py b/egs/baker_zh/TTS/matcha/onnx_pretrained.py deleted file mode 100755 index f6b7f7cae..000000000 --- a/egs/baker_zh/TTS/matcha/onnx_pretrained.py +++ /dev/null @@ -1,316 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) - -""" -python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-4.onnx \ - --vocoder ./hifigan_v2.onnx \ - --tokens ./data/tokens.txt \ - --lexicon ./lexicon.txt \ - --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ - --output-wav ./b.wav -""" - -import argparse -import datetime as dt -import logging -import re -from typing import Dict, List - -import jieba -import onnxruntime as ort -import soundfile as sf -import torch -from infer import load_vocoder -from utils import intersperse - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--acoustic-model", - type=str, - required=True, - help="Path to the acoustic model", - ) - - parser.add_argument( - "--tokens", - type=str, - required=True, - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--lexicon", - type=str, - required=True, - help="Path to the lexicon.txt", - ) - - parser.add_argument( - "--vocoder", - type=str, - required=True, - help="Path to the vocoder", - ) - - parser.add_argument( - "--input-text", - type=str, - required=True, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=True, - help="The filename of the wave to save the generated speech", - ) - - return parser - - -class OnnxHifiGANModel: - def __init__( - self, - filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - self.model = ort.InferenceSession( - filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - for i in self.model.get_inputs(): - print(i) - - print("-----") - - for i in self.model.get_outputs(): - print(i) - - def __call__(self, x: torch.tensor): - assert x.ndim == 3, x.shape - assert x.shape[0] == 1, x.shape - - audio = self.model.run( - [self.model.get_outputs()[0].name], - { - self.model.get_inputs()[0].name: x.numpy(), - }, - )[0] - # audio: (batch_size, num_samples) - - return torch.from_numpy(audio) - - -class OnnxModel: - def __init__( - self, - filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 2 - - self.session_opts = session_opts - self.model = ort.InferenceSession( - filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - metadata = self.model.get_modelmeta().custom_metadata_map - self.sample_rate = int(metadata["sample_rate"]) - - for i in self.model.get_inputs(): - print(i) - - print("-----") - - for i in self.model.get_outputs(): - print(i) - - def __call__(self, x: torch.tensor): - assert x.ndim == 2, x.shape - assert x.shape[0] == 1, x.shape - - x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - print("x_lengths", x_lengths) - print("x", x.shape) - - noise_scale = torch.tensor([1.0], dtype=torch.float32) - length_scale = torch.tensor([1.0], dtype=torch.float32) - - mel = self.model.run( - [self.model.get_outputs()[0].name], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lengths.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: length_scale.numpy(), - }, - )[0] - # mel: (batch_size, feat_dim, num_frames) - - return torch.from_numpy(mel) - - -def 
read_tokens(filename: str) -> Dict[str, int]: - token2id = dict() - with open(filename, encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - idx = int(info[0]) - else: - token, idx = info[0], int(info[1]) - assert token not in token2id, token - token2id[token] = idx - return token2id - - -def read_lexicon(filename: str) -> Dict[str, List[str]]: - word2token = dict() - with open(filename, encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - w = info[0] - tokens = info[1:] - word2token[w] = tokens - return word2token - - -def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]: - if word in word2tokens: - return word2tokens[word] - - if len(word) == 1: - return [] - - ans = [] - for w in word: - t = convert_word_to_tokens(word2tokens, w) - ans.extend(t) - return ans - - -def normalize_text(text): - whiter_space_re = re.compile(r"\s+") - - punctuations_re = [ - (re.compile(x[0], re.IGNORECASE), x[1]) - for x in [ - (",", ","), - ("。", "."), - ("!", "!"), - ("?", "?"), - ("“", '"'), - ("”", '"'), - ("‘", "'"), - ("’", "'"), - (":", ":"), - ("、", ","), - ] - ] - - for regex, replacement in punctuations_re: - text = re.sub(regex, replacement, text) - return text - - -@torch.no_grad() -def main(): - params = get_parser().parse_args() - logging.info(vars(params)) - token2id = read_tokens(params.tokens) - word2tokens = read_lexicon(params.lexicon) - - text = normalize_text(params.input_text) - seg = jieba.cut(text) - tokens = [] - for s in seg: - if s in token2id: - tokens.append(s) - continue - - t = convert_word_to_tokens(word2tokens, s) - if t: - tokens.extend(t) - - model = OnnxModel(params.acoustic_model) - vocoder = OnnxHifiGANModel(params.vocoder) - - x = [] - for t in tokens: - if t in token2id: - x.append(token2id[t]) - - x = intersperse(x, item=token2id["_"]) - - x = torch.tensor(x, dtype=torch.int64).unsqueeze(0) - - start_t = dt.datetime.now() - mel = model(x) - end_t = dt.datetime.now() - - start_t2 = dt.datetime.now() - audio = vocoder(mel) - end_t2 = dt.datetime.now() - - print("audio", audio.shape) # (1, 1, num_samples) - audio = audio.squeeze() - - sample_rate = model.sample_rate - - t = (end_t - start_t).total_seconds() - t2 = (end_t2 - start_t2).total_seconds() - rtf_am = t * sample_rate / audio.shape[-1] - rtf_vocoder = t2 * sample_rate / audio.shape[-1] - print("RTF for acoustic model ", rtf_am) - print("RTF for vocoder", rtf_vocoder) - - # skip denoiser - sf.write(params.output_wav, audio, sample_rate, "PCM_16") - logging.info(f"Saved to {params.output_wav}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() - -""" - -|HifiGAN |RTF |#Parameters (M)| -|----------|-----|---------------| -|v1 |0.818| 13.926 | -|v2 |0.101| 0.925 | -|v3 |0.118| 1.462 | - -|Num steps|Acoustic Model RTF| -|---------|------------------| -| 2 | 0.039 | -| 3 | 0.047 | -| 4 | 0.071 | -| 5 | 0.076 | -| 6 | 0.103 | - -""" diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py deleted file mode 100644 index dda82c29d..000000000 --- a/egs/baker_zh/TTS/matcha/tokenizer.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) - -import logging -from typing import Dict, List - -import tacotron_cleaner.cleaners - -try: - from piper_phonemize import phonemize_espeak -except Exception as ex: - raise RuntimeError( - f"{ex}\nPlease run\n" - "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" - ) - -from utils import intersperse - - -# This tokenizer supports both English and Chinese. -# We assume you have used -# ../local/convert_text_to_tokens.py -# to process your text -class Tokenizer(object): - def __init__(self, tokens: str): - """ - Args: - tokens: the file that maps tokens to ids - """ - # Parse token file - self.token2id: Dict[str, int] = {} - with open(tokens, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - id = int(info[0]) - else: - token, id = info[0], int(info[1]) - assert token not in self.token2id, token - self.token2id[token] = id - - # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md - self.pad_id = self.token2id["_"] # padding - self.space_id = self.token2id[" "] # word separator (whitespace) - - self.vocab_size = len(self.token2id) - - def texts_to_token_ids( - self, - sentence_list: List[List[str]], - intersperse_blank: bool = True, - lang: str = "en-us", - ) -> List[List[int]]: - """ - Args: - sentence_list: - A list of sentences. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - lang: - Language argument passed to phonemize_espeak(). - - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for sentence in sentence_list: - tokens_list = [] - for word in sentence: - if word in self.token2id: - tokens_list.append(word) - continue - - tmp_tokens_list = phonemize_espeak(word, lang) - for t in tmp_tokens_list: - tokens_list.extend(t) - - token_ids = [] - for t in tokens_list: - if t not in self.token2id: - logging.warning(f"Skip OOV {t} {sentence}") - continue - - if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id: - continue - - token_ids.append(self.token2id[t]) - - if intersperse_blank: - token_ids = intersperse(token_ids, self.pad_id) - - token_ids_list.append(token_ids) - - return token_ids_list - - -def test_tokenizer(): - import jieba - from pypinyin import Style, lazy_pinyin - - tokenizer = Tokenizer("data/tokens.txt") - text1 = "今天is Monday, tomorrow is 星期二" - text2 = "你好吗? 我很好, how about you?" - - text1 = list(jieba.cut(text1)) - text2 = list(jieba.cut(text2)) - tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True) - tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True) - print(tokens1) - print(tokens2) - - ids = tokenizer.texts_to_token_ids([tokens1, tokens2]) - print(ids) - - -if __name__ == "__main__": - test_tokenizer() diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py deleted file mode 100755 index ed2ba49b9..000000000 --- a/egs/baker_zh/TTS/matcha/train.py +++ /dev/null @@ -1,717 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) - - -import argparse -import json -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Union - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.utils import fix_random_seed -from model import fix_len_compatibility -from models.matcha_tts import MatchaTTS -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import BakerZhTtsDataModule -from utils import MetricsTracker - -from icefall.checkpoint import load_checkpoint, save_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12335, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=10, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_data_statistics(): - return AttributeDict( - { - "mel_mean": 0, - "mel_std": 1, - } - ) - - -def _get_data_params() -> AttributeDict: - params = AttributeDict( - { - "name": "baker-zh", - "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", - "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - # "batch_size": 64, - # "num_workers": 1, - # "pin_memory": False, - "cleaners": ["english_cleaners2"], - "add_blank": True, - "n_spks": 1, - "n_fft": 1024, - "n_feats": 80, - "sampling_rate": 22050, - "hop_length": 256, - "win_length": 1024, - "f_min": 0, - "f_max": 8000, - "seed": 1234, - "load_durations": False, - "data_statistics": get_data_statistics(), - } - ) - return params - - -def _get_model_params() -> AttributeDict: - n_feats = 80 - filter_channels_dp = 256 - encoder_params_p_dropout = 0.1 - params = AttributeDict( - { - "n_spks": 1, # for baker-zh. - "spk_emb_dim": 64, - "n_feats": n_feats, - "out_size": None, # or use 172 - "prior_loss": True, - "use_precomputed_durations": False, - "data_statistics": get_data_statistics(), - "encoder": AttributeDict( - { - "encoder_type": "RoPE Encoder", # not used - "encoder_params": AttributeDict( - { - "n_feats": n_feats, - "n_channels": 192, - "filter_channels": 768, - "filter_channels_dp": filter_channels_dp, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - "spk_emb_dim": 64, - "n_spks": 1, - "prenet": True, - } - ), - "duration_predictor_params": AttributeDict( - { - "filter_channels_dp": filter_channels_dp, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - } - ), - } - ), - "decoder": AttributeDict( - { - "channels": [256, 256], - "dropout": 0.05, - "attention_head_dim": 64, - "n_blocks": 1, - "num_mid_blocks": 2, - "num_heads": 2, - "act_fn": "snakebeta", - } - ), - "cfm": AttributeDict( - { - "name": "CFM", - "solver": "euler", - "sigma_min": 1e-4, - } - ), - "optimizer": AttributeDict( - { - "lr": 1e-4, - "weight_decay": 0.0, - } - ), - } - ) - - return params - - -def get_params(): - params = AttributeDict( - { - "model_args": _get_model_params(), - "data_args": _get_data_params(), - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 10, - "valid_interval": 1500, - "env_info": get_env_info(), - } - ) - return params - - -def get_model(params): - m = MatchaTTS(**params.model_args) - return m - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
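- # load_checkpoint() restores the model weights in place and returns the saved
- # training metadata; the bookkeeping values listed in `keys` below are copied
- # back into `params` so that best-loss tracking survives a resumed run.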
- - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): - """Parse batch data""" - mel_mean = params.data_args.data_statistics.mel_mean - mel_std_inv = 1 / params.data_args.data_statistics.mel_std - for i in range(batch["features"].shape[0]): - n = batch["features_lens"][i] - batch["features"][i : i + 1, :n, :] = ( - batch["features"][i : i + 1, :n, :] - mel_mean - ) * mel_std_inv - batch["features"][i : i + 1, n:, :] = 0 - - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - - tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - max_feature_length = fix_len_compatibility(features.shape[1]) - if max_feature_length > features.shape[1]: - pad = max_feature_length - features.shape[1] - features = torch.nn.functional.pad(features, (0, 0, 0, pad)) - - # features_lens[features_lens.argmax()] += pad - - return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long() - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) - - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) - - batch_size = len(batch["tokens"]) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - s = 0 - - for key, value in losses.items(): - v = value.detach().item() - loss_info[key] = v * batch_size - s += v * batch_size - - loss_info["tot_loss"] = s - - # summary stats - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["tot_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer: Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer=optimizer, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - # audio: (N, T), float32 - # features: (N, T, C), float32 - # audio_lens, (N,), int32 - # features_lens, (N,), int32 - # tokens: List[List[str]], len(tokens) == N - - batch_size = len(batch["tokens"]) - - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) - try: - with autocast(enabled=params.use_fp16): - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) - - loss = sum(losses.values()) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - s = 0 - - for key, value in losses.items(): - v = value.detach().item() - loss_info[key] = v * batch_size - s += v * batch_size - - loss_info["tot_loss"] = s - - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. - # The _growth_interval of the grad scaler is configurable, - # but we can't configure it to have different - # behavior depending on the current grad scale. 
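- # Heuristic: if the scale has dropped below 8 (or below 32, checked every 400
- # batches), double it manually. A scale that keeps shrinking usually means the
- # gradients are non-finite, so a bad-model checkpoint is saved and training
- # eventually aborts.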
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, " - f"batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if params.batch_idx_train % params.valid_interval == 1: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - rank=rank, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - "Maximum memory allocated so far is " - f"{torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["tot_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.pad_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - - logging.info(params) - print(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of parameters: {num_param}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = 
torch.optim.Adam(model.parameters(), **params.model_args.optimizer) - - logging.info("About to create datamodule") - - baker_zh = BakerZhTtsDataModule(args) - - train_cuts = baker_zh.train_cuts() - train_dl = baker_zh.train_dataloaders(train_cuts) - - valid_cuts = baker_zh.valid_cuts() - valid_dl = baker_zh.valid_dataloaders(valid_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - fix_random_seed(params.seed + epoch - 1) - if "sampler" in train_dl: - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer=optimizer, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer=optimizer, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - BakerZhTtsDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/baker_zh/TTS/matcha/tts_datamodule.py b/egs/baker_zh/TTS/matcha/tts_datamodule.py deleted file mode 100644 index d2bdfb96c..000000000 --- a/egs/baker_zh/TTS/matcha/tts_datamodule.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
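- # A minimal usage sketch, assuming the fbank manifests under data/fbank have
- # already been generated by ./prepare.sh:
- #
- #   parser = argparse.ArgumentParser()
- #   BakerZhTtsDataModule.add_arguments(parser)
- #   args = parser.parse_args([])
- #   datamodule = BakerZhTtsDataModule(args)
- #   train_dl = datamodule.train_dataloaders(datamodule.train_cuts())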
- - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from fbank import MatchaFbank, MatchaFbankConfig -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class BakerZhTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
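- # _SeedWorkers (defined at the top of this file) then re-seeds each dataloader
- # worker with seed + worker_id, so different workers draw different random
- # augmentations while the run remains reproducible for a fixed base seed.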
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - pin_memory=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=True, - pin_memory=True, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" - ) diff --git a/egs/baker_zh/TTS/matcha/utils.py b/egs/baker_zh/TTS/matcha/utils.py deleted file mode 120000 index ceaaea196..000000000 --- a/egs/baker_zh/TTS/matcha/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh deleted file mode 100755 index e15e3d850..000000000 --- a/egs/baker_zh/TTS/prepare.sh +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export 
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download -mkdir -p $dl_dir - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib (used by ./matcha)" - for recipe in matcha; do - if [ ! -d $recipe/monotonic_align/build ]; then - cd $recipe/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for $recipe already built" - fi - done -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # The directory $dl_dir/BANSYP contains the following 3 directories - - # ls -lh $dl_dir/BZNSYP/ - # total 0 - # drwxr-xr-x 10002 kuangfangjun root 0 Jan 4 2019 PhoneLabeling - # drwxr-xr-x 3 kuangfangjun root 0 Jan 31 2019 ProsodyLabeling - # drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave - - # If you have trouble accessing huggingface.co, please use - # - # cd $dl_dir - # wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2 - # tar xf BZNSYP.tar.bz2 - # cd .. - - # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink - # - # ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP - # - if [ ! -d $dl_dir/BZNSYP/Wave ]; then - lhotse download baker-zh $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare baker-zh manifest" - # We assume that you have downloaded the baker corpus - # to $dl_dir/BZNSYP - mkdir -p data/manifests - if [ ! -e data/manifests/.baker-zh.done ]; then - lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests - touch data/manifests/.baker-zh.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Generate tokens.txt" - if [ ! -e data/tokens.txt ]; then - python3 ./local/generate_tokens.py --tokens data/tokens.txt - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Generate raw cutset" - if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then - lhotse cut simple \ - -r ./data/manifests/baker_zh_recordings_all.jsonl.gz \ - -s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \ - ./data/manifests/baker_zh_cuts_raw.jsonl.gz - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Convert text to tokens" - if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then - python3 ./local/convert_text_to_tokens.py \ - --in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \ - --out-file ./data/manifests/baker_zh_cuts.jsonl.gz - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Generate fbank (used by ./matcha)" - mkdir -p data/fbank - if [ ! -e data/fbank/.baker-zh.done ]; then - ./local/compute_fbank_baker_zh.py - touch data/fbank/.baker-zh.done - fi - - if [ ! -e data/fbank/.baker-zh-validated.done ]; then - log "Validating data/fbank for baker-zh (used by ./matcha)" - python3 ./local/validate_manifest.py \ - data/fbank/baker_zh_cuts.jsonl.gz - touch data/fbank/.baker-zh-validated.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)" - if [ ! 
-e data/fbank/.baker_zh_split.done ]; then - lhotse subset --last 600 \ - data/fbank/baker_zh_cuts.jsonl.gz \ - data/fbank/baker_zh_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/fbank/baker_zh_cuts_validtest.jsonl.gz \ - data/fbank/baker_zh_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/fbank/baker_zh_cuts_validtest.jsonl.gz \ - data/fbank/baker_zh_cuts_test.jsonl.gz - - rm data/fbank/baker_zh_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 )) - - lhotse subset --first $n \ - data/fbank/baker_zh_cuts.jsonl.gz \ - data/fbank/baker_zh_cuts_train.jsonl.gz - - touch data/fbank/.baker_zh_split.done - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank mean and std (used by ./matcha)" - if [ ! -f ./data/fbank/cmvn.json ]; then - ./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json - fi -fi diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/baker_zh/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/commonvoice/ASR/README.md b/egs/commonvoice/ASR/README.md deleted file mode 100644 index a4582499b..000000000 --- a/egs/commonvoice/ASR/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Introduction - -This recipe includes several different ASR models trained with Common Voice. - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this directory. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | |---------------------------------------|---------------------|--------------------|---------------------------------------------------| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan | - -The decoder in `transducer_stateless` is modified from the paper -[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md deleted file mode 100644 index f384f66a0..000000000 --- a/egs/commonvoice/ASR/RESULTS.md +++ /dev/null @@ -1,151 +0,0 @@ -## Results - -### Commonvoice Cantonese (zh-HK) Char training results (Zipformer) - -See #1546 for more details.
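The README above describes the transducer decoder as a stateless prediction network: an embedding over the last few emitted labels followed by a Conv1d, with no recurrent state. A rough sketch of that idea follows; it is not the recipe's actual decoder.py, and the dimensions, grouping and ReLU are illustrative assumptions only.

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d over the last `context_size` labels; no recurrence."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The Conv1d mixes the embeddings of the current and previous label(s).
        self.conv = nn.Conv1d(
            embed_dim, embed_dim, kernel_size=context_size, groups=embed_dim // 4, bias=False
        )
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, U) previously emitted labels
        emb = self.embedding(y).permute(0, 2, 1)                  # (B, D, U)
        emb = nn.functional.pad(emb, (self.context_size - 1, 0))  # causal left padding
        return torch.relu(self.conv(emb)).permute(0, 2, 1)        # (B, U, D)


if __name__ == "__main__":
    dec = StatelessDecoder(vocab_size=500)
    print(dec(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])
```

Because the prediction network only sees a fixed left context of labels, it behaves like a lightweight n-gram-conditioned decoder rather than an LSTM, which is the main point of the stateless design.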
- -Number of model parameters: 72526519, i.e., 72.53 M - -The best CER, for CommonVoice 16.1 (cv-corpus-16.1-2023-12-06/zh-HK) is below: - -| | Dev | Test | Note | -|----------------------|-------|------|--------------------| -| greedy_search | 1.17 | 1.22 | --epoch 24 --avg 5 | -| modified_beam_search | 0.98 | 1.11 | --epoch 24 --avg 5 | -| fast_beam_search | 1.08 | 1.27 | --epoch 24 --avg 5 | - -When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (w/o blank penalty), -the best CER is below: - -| | Dev | Test | Note | -|----------------------|-------|------|--------------------| -| greedy_search | 42.40 | 42.03| --epoch 24 --avg 5 | -| modified_beam_search | 39.73 | 39.19| --epoch 24 --avg 5 | -| fast_beam_search | 42.14 | 41.98| --epoch 24 --avg 5 | - -When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (with blank penalty set to 2.2), -the best CER is below: - -| | Dev | Test | Note | -|----------------------|-------|------|----------------------------------------| -| greedy_search | 39.19 | 39.09| --epoch 24 --avg 5 --blank-penalty 2.2 | -| modified_beam_search | 37.73 | 37.65| --epoch 24 --avg 5 --blank-penalty 2.2 | -| fast_beam_search | 37.73 | 37.74| --epoch 24 --avg 5 --blank-penalty 2.2 | - -To reproduce the above result, use the following commands for training: - -```bash -export CUDA_VISIBLE_DEVICES="0,1" -./zipformer/train_char.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --cv-manifest-dir data/zh-HK/fbank \ - --language zh-HK \ - --use-validated-set 1 \ - --context-size 1 \ - --max-duration 1000 -``` - -and the following commands for decoding: - -```bash -for method in greedy_search modified_beam_search fast_beam_search; do - ./zipformer/decode_char.py \ - --epoch 24 \ - --avg 5 \ - --decoding-method $method \ - --exp-dir zipformer/exp \ - --cv-manifest-dir data/zh-HK/fbank \ - --context-size 1 \ - --language zh-HK -done -``` - -Detailed experimental results and pre-trained model are available at: - - - -### CommonVoice English (en) BPE training results (Pruned Stateless Transducer 7) - -#### [pruned_transducer_stateless7](./pruned_transducer_stateless7) - -See #997 for more details. 
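The decode commands above pass `--epoch 24 --avg 5`, i.e. the checkpoint of epoch 24 averaged with the preceding epochs before decoding. Below is a minimal sketch of that kind of parameter averaging, assuming checkpoints named `epoch-N.pt` that keep their weights under a `model` key; the file layout and the plain arithmetic mean are assumptions for illustration, not the exact icefall implementation.

```python
from pathlib import Path
from typing import List

import torch


def average_checkpoints(filenames: List[Path]) -> dict:
    """Return the element-wise average of the 'model' state dicts."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            # integer buffers such as BatchNorm's num_batches_tracked
            avg[k] //= n
    return avg


epoch, num_avg = 24, 5  # as in "--epoch 24 --avg 5"
ckpts = [Path(f"zipformer/exp/epoch-{i}.pt") for i in range(epoch - num_avg + 1, epoch + 1)]
# model.load_state_dict(average_checkpoints(ckpts))
```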
- -Number of model parameters: 70369391, i.e., 70.37 M - -Note that the result is obtained using GigaSpeech transcript trained BPE model - -The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below: - -Results are: - -| | Dev | Test | -|----------------------|-------|-------| -| greedy_search | 9.96 | 12.54 | -| modified_beam_search | 9.86 | 12.48 | - -To reproduce the above result, use the following commands for training: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 550 -``` - -and the following commands for decoding: - -```bash -# greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 30 \ - --avg 5 \ - --decoding-method greedy_search \ - --exp-dir pruned_transducer_stateless7/exp \ - --bpe-model data/en/lang_bpe_500/bpe.model \ - --max-duration 600 - -# modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 30 \ - --avg 5 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --exp-dir pruned_transducer_stateless7/exp \ - --bpe-model data/en/lang_bpe_500/bpe.model \ - --max-duration 600 -``` - -Pretrained model is available at - - -### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming) - -#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) - -See #1018 for more details. - -Number of model parameters: 70369391, i.e., 70.37 M - -The best WER for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) is below: - -Results are: - -| decoding method | Test | -|----------------------|-------| -| greedy_search | 9.95 | -| modified_beam_search | 9.57 | -| fast_beam_search | 9.67 | - -Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice. - -Detailed experimental results and Pretrained model are available at - - diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py deleted file mode 100755 index 6512aa68b..000000000 --- a/egs/commonvoice/ASR/local/compile_hlg.py +++ /dev/null @@ -1,168 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
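The decoding commands above point `--bpe-model` at `data/en/lang_bpe_500/bpe.model`. For orientation, this is how such a sentencepiece model is typically used to map a normalized transcript to the token ids consumed by the transducer; the model path and the sample sentence are placeholders.

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/en/lang_bpe_500/bpe.model")  # placeholder path

text = "COMMON VOICE IS A CROWDSOURCED CORPUS"
pieces = sp.encode(text, out_type=str)  # subword strings
ids = sp.encode(text, out_type=int)     # token ids fed to the transducer
print(pieces)
print(ids)
print(sp.decode(ids))                   # should round-trip back to the text
```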
- - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_n_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lm", - type=str, - default="G_3_gram", - help="""Stem name for LM used in HLG compiling. - """, - ) - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - lm: - The language stem base name. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): - logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"{lang_dir}/lm/{lm}.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info(f"Loading {lm}.fst.txt") - with open(f"{lang_dir}/lm/{lm}.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. 
- HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir, args.lm) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py deleted file mode 100755 index 76dacb5b2..000000000 --- a/egs/commonvoice/ASR/local/compile_lg.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Kang Wei, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates LG from - - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from lang_dir/lm/G_3_gram.fst.txt - -The generated LG is saved in $lang_dir/LG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - parser.add_argument( - "--lm", - type=str, - default="G_3_gram", - help="""Stem name for LM used in HLG compiling. - """, - ) - - return parser.parse_args() - - -def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing LG. 
- """ - lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): - logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"{lang_dir}/lm/{lm}.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info(f"Loading {lm}.fst.txt") - with open(f"{lang_dir}/lm/{lm}.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - return LG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "LG.pt").is_file(): - logging.info(f"{lang_dir}/LG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - LG = compile_LG(lang_dir, args.lm) - logging.info(f"Saving LG.pt to {lang_dir}") - torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py deleted file mode 100755 index a0b4d224c..000000000 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the CommonVoice dataset. -It looks for manifests in the directory data/${lang}/manifests. - -The generated fbank features are saved in data/${lang}/fbank. 
-""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import torch -from filter_cuts import filter_cuts -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--language", - type=str, - help="""Language of Common Voice""", - ) - - return parser.parse_args() - - -def compute_fbank_commonvoice_dev_test(language: str): - src_dir = Path(f"data/{language}/manifests") - output_dir = Path(f"data/{language}/fbank") - num_workers = 16 - batch_duration = 200 - - subsets = ("dev", "test") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - - logging.info(f"device: {device}") - - for partition in subsets: - cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - - raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/cv-{language}_feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - compute_fbank_commonvoice_dev_test(language=args.language) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py deleted file mode 100755 index aa672609a..000000000 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023-2024 Xiaomi Corp. (Yifan Yang, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
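Once the fbank stage above has run, each cut in the written manifest carries a reference to its precomputed features. A small sketch for inspecting the result; the path assumes `--language en`, so adjust as needed.

```python
from lhotse import CutSet

cuts = CutSet.from_file("data/en/fbank/cv-en_cuts_dev.jsonl.gz")  # placeholder path
cut = next(iter(cuts))
print(cut.id, cut.duration, cut.supervisions[0].text)

feats = cut.load_features()  # numpy array: (num_frames, num_mel_bins)
print(feats.shape)
```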
- -import argparse -import logging -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - set_audio_duration_mismatch_tolerance, - set_caching_enabled, -) - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--subset", - type=str, - default="train", - choices=["train", "validated", "invalidated"], - help="""Dataset parts to compute fbank. """, - ) - - parser.add_argument( - "--language", - type=str, - help="""Language of Common Voice""", - ) - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--num-splits", - type=int, - required=True, - help="The number of splits of the subset", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="Process pieces starting from this number (included).", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop processing pieces until this number (excluded).", - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", - ) - - return parser.parse_args() - - -def compute_fbank_commonvoice_splits(args): - subset = args.subset - num_splits = args.num_splits - language = args.language - output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}" - output_dir = Path(output_dir) - assert output_dir.exists(), f"{output_dir} does not exist!" 
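This script consumes the split pieces that prepare.sh creates with the `lhotse split` CLI. A rough Python equivalent of that splitting step, writing zero-padded piece indices so that `--start`/`--stop` can address a contiguous range, is sketched below; the paths and the `en` language code are placeholders.

```python
from pathlib import Path

from lhotse import CutSet

language = "en"    # placeholder
num_splits = 1000  # matches num_splits in prepare.sh
out_dir = Path(f"data/{language}/fbank/cv-{language}_train_split_{num_splits}")
out_dir.mkdir(parents=True, exist_ok=True)

cuts = CutSet.from_file(f"data/{language}/fbank/cv-{language}_cuts_train_raw.jsonl.gz")
num_digits = len(str(num_splits))
for i, piece in enumerate(cuts.split(num_splits)):
    idx = str(i).zfill(num_digits)  # zero-padded, as expected by --start/--stop
    piece.to_file(out_dir / f"cv-{language}_cuts_train_raw.{idx}.jsonl.gz")
```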
- - num_digits = len(str(num_splits)) - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - set_audio_duration_mismatch_tolerance(0.05) # 50ms tolerance - set_caching_enabled(False) - for i in range(start, stop): - idx = f"{i}".zfill(num_digits) - logging.info(f"Processing {idx}/{num_splits}") - - cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - if args.perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - compute_fbank_commonvoice_splits(args) - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/local/compute_fbank_musan.py b/egs/commonvoice/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/commonvoice/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/filter_cuts.py b/egs/commonvoice/ASR/local/filter_cuts.py deleted file mode 120000 index 27aca1729..000000000 --- a/egs/commonvoice/ASR/local/filter_cuts.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_char.py b/egs/commonvoice/ASR/local/prepare_char.py deleted file mode 120000 index 42743b544..000000000 --- a/egs/commonvoice/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang.py b/egs/commonvoice/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/commonvoice/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang_bpe.py b/egs/commonvoice/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/commonvoice/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang_fst.py b/egs/commonvoice/ASR/local/prepare_lang_fst.py deleted file mode 120000 index 
c5787c534..000000000 --- a/egs/commonvoice/ASR/local/prepare_lang_fst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py deleted file mode 100755 index cc88ef8d7..000000000 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import re -from pathlib import Path -from typing import Optional - -from lhotse import CutSet -from lhotse.recipes.utils import read_manifests_if_cached - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - - parser.add_argument( - "--language", - type=str, - help="""Language of Common Voice""", - ) - - return parser.parse_args() - - -def normalize_text(utt: str, language: str) -> str: - utt = re.sub(r"[{0}]+".format("-"), " ", utt) - utt = re.sub("’", "'", utt) - if language == "en": - return re.sub(r"[^a-zA-Z\s]", "", utt).upper() - elif language == "fr": - return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() - elif language == "pl": - return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper() - elif language in ["yue", "zh-HK"]: - # Mozilla Common Voice uses both "yue" and "zh-HK" for Cantonese - # Not sure why they decided to do this... - # None en/zh-yue tokens are manually removed here - - # fmt: off - tokens_to_remove = [",", "。", "?", "!", "?", "!", "‘", "、", ",", "\.", ":", ";", "「", "」", "“", "”", "~", "—", "ㄧ", "《", "》", "…", "⋯", "·", "﹒", ".", ":", "︰", "﹖", "(", ")", "-", "~", ";", "", "⠀", "﹔", "/", "A", "B", "–", "‧"] - - # fmt: on - utt = utt.upper().replace("\\", "") - return re.sub( - pattern="|".join([f"[{token}]" for token in tokens_to_remove]), - repl="", - string=utt, - ) - else: - raise NotImplementedError( - f""" - Text normalization not implemented for language: {language}, - please consider implementing it in the local/preprocess_commonvoice.py - or raise an issue on GitHub to request it. 
- """ - ) - - -def preprocess_commonvoice( - language: str, - dataset: Optional[str] = None, -): - src_dir = Path(f"data/{language}/manifests") - output_dir = Path(f"data/{language}/fbank") - output_dir.mkdir(exist_ok=True) - - if dataset is None: - dataset_parts = ( - "dev", - "test", - "train", - ) - else: - dataset_parts = dataset.split(" ", -1) - - logging.info("Loading manifest") - prefix = f"cv-{language}" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix=suffix, - prefix=prefix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - logging.info(f"Normalizing text in {partition}") - for sup in m["supervisions"]: - text = str(sup.text) - orig_text = text - sup.text = normalize_text(sup.text, language) - text = str(sup.text) - if len(orig_text) != len(text): - logging.info( - f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" - ) - - # Create long-recording cut manifests. - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ).resample(16000) - - if partition == "validated": - logging.warning( - """ - The 'validated' partition contains the data of both 'train', 'dev' - and 'test' partitions. We filter out the 'dev' and 'test' partition - here. - """ - ) - dev_ids = src_dir / f"cv-{language}_dev_ids" - test_ids = src_dir / f"cv-{language}_test_ids" - assert ( - dev_ids.is_file() - ), f"{dev_ids} does not exist, please check stage 1 of the prepare.sh" - assert ( - test_ids.is_file() - ), f"{test_ids} does not exist, please check stage 1 of the prepare.sh" - dev_ids = dev_ids.read_text().strip().split("\n") - test_ids = test_ids.read_text().strip().split("\n") - cut_set = cut_set.filter( - lambda x: x.supervisions[0].id not in dev_ids + test_ids - ) - - # Run data augmentation that needs to be done in the - # time domain. 
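The `validated` branch above filters out cuts whose supervision ids appear in the dev/test id lists. The same filter can be written with the ids held in a set, so each membership test is constant time; this is only an equivalent sketch with `en` as a placeholder language, not a change to the recipe.

```python
from pathlib import Path

from lhotse import CutSet

language = "en"  # placeholder
src_dir = Path(f"data/{language}/manifests")
held_out = set((src_dir / f"cv-{language}_dev_ids").read_text().strip().split("\n"))
held_out |= set((src_dir / f"cv-{language}_test_ids").read_text().strip().split("\n"))

cut_set = CutSet.from_file(f"data/{language}/fbank/cv-{language}_cuts_validated_raw.jsonl.gz")
# Keep only cuts that are not part of the dev or test partitions.
cut_set = cut_set.filter(lambda c: c.supervisions[0].id not in held_out)
```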
- logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - preprocess_commonvoice( - language=args.language, - dataset=args.dataset, - ) - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/local/train_bpe_model.py b/egs/commonvoice/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/commonvoice/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/validate_bpe_lexicon.py b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/commonvoice/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/word_segment_yue.py b/egs/commonvoice/ASR/local/word_segment_yue.py deleted file mode 100755 index 35d262d10..000000000 --- a/egs/commonvoice/ASR/local/word_segment_yue.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
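For a concrete feel of what `normalize_text()` in preprocess_commonvoice.py above does, here is its English branch applied to a made-up sentence; the regexes follow that function, and only the sample input is invented.

```python
import re


def normalize_en(utt: str) -> str:
    utt = re.sub(r"[-]+", " ", utt)  # hyphens become spaces
    utt = re.sub("’", "'", utt)      # curly apostrophe -> ASCII
    return re.sub(r"[^a-zA-Z\s]", "", utt).upper()


print(normalize_en("It’s a so-called ‘crowd-sourced’ corpus!"))
# ITS A SO CALLED CROWD SOURCED CORPUS
```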
- -""" -This script takes a text file "data/lang_char/text" as input, the file consists of -lines each containing a transcript, applies text norm and generates the following -files in the directory "data/lang_char": - - transcript_words.txt - - words.txt - - words_no_ids.txt -""" - -import argparse -import logging -import re -from pathlib import Path -from typing import List - -import pycantonese -from preprocess_commonvoice import normalize_text -from tqdm.auto import tqdm - -from icefall.utils import is_cjk, tokenize_by_CJK_char - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare char lexicon", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - "-i", - default="data/yue/lang_char/text", - type=str, - help="The input text file", - ) - parser.add_argument( - "--output-dir", - "-o", - default="data/yue/lang_char/", - type=str, - help="The output directory", - ) - parser.add_argument( - "--lang", - "-l", - default="yue", - type=str, - help="The language", - ) - return parser - - -def get_word_segments(lines: List[str]) -> List[str]: - # the current pycantonese segmenter does not handle the case when the input - # is code switching, so we need to handle it separately - - new_lines = [] - - for line in tqdm(lines, desc="Segmenting lines"): - try: - if is_cs(line): # code switching - segments = [] - curr_str = "" - for segment in tokenize_by_CJK_char(line).split(" "): - if segment.strip() == "": - continue - try: - if not is_cjk(segment[0]): # en segment - if curr_str: - segments.extend(pycantonese.segment(curr_str)) - curr_str = "" - segments.append(segment) - else: # zh segment - curr_str += segment - # segments.extend(pycantonese.segment(segment)) - except Exception as e: - logging.error(f"Failed to process segment: {segment}") - raise - if curr_str: # process the last segment - segments.extend(pycantonese.segment(curr_str)) - new_lines.append(" ".join(segments) + "\n") - else: # not code switching - new_lines.append(" ".join(pycantonese.segment(line)) + "\n") - except Exception as e: - logging.error(f"Failed to process line: {line}") - raise e - return new_lines - - -def get_words(lines: List[str]) -> List[str]: - words = set() - for line in tqdm(lines, desc="Getting words"): - words.update(line.strip().split(" ")) - return list(words) - - -def is_cs(line: str) -> bool: - english_markers = r"[a-zA-Z]+" - return bool(re.search(english_markers, line)) - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - - input_file = Path(args.input_file) - output_dir = Path(args.output_dir) - lang = args.lang - - assert input_file.is_file(), f"{input_file} does not exist" - assert output_dir.is_dir(), f"{output_dir} does not exist" - - lines = input_file.read_text(encoding="utf-8").strip().split("\n") - norm_lines = [normalize_text(line, lang) for line in lines] - - text_words_segments = get_word_segments(norm_lines) - with open(output_dir / "transcript_words.txt", "w", encoding="utf-8") as f: - f.writelines(text_words_segments) - - words = get_words(text_words_segments)[1:] # remove "\n" from words - with open(output_dir / "words_no_ids.txt", "w", encoding="utf-8") as f: - f.writelines([word + "\n" for word in sorted(words)]) - - words = ( - ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"] - + sorted(words) - + ["#0", "<s>", "</s>"] - ) - - with open(output_dir / "words.txt", "w", encoding="utf-8") as f: - f.writelines([f"{word} {i}\n" for i, word in enumerate(words)]) diff --git a/egs/commonvoice/ASR/prepare.sh
b/egs/commonvoice/ASR/prepare.sh deleted file mode 100755 index 200114a86..000000000 --- a/egs/commonvoice/ASR/prepare.sh +++ /dev/null @@ -1,508 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -nj=16 -stage=-1 -stop_stage=100 - -# Split data/${lang}set to this number of pieces -# This is to avoid OOM during feature extraction. -num_splits=1000 - -# In case you want to use all validated data -use_validated=false - -# In case you are willing to take the risk and use invalidated data -use_invalidated=false - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/$release/$lang -# This directory contains the following files downloaded from -# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz -# -# - clips -# - dev.tsv -# - invalidated.tsv -# - other.tsv -# - reported.tsv -# - test.tsv -# - train.tsv -# - validated.tsv -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download -release=cv-corpus-12.0-2022-12-07 -lang=fr -perturb_speed=false - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/${lang}/lang_bpe_xxx, -# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data/${lang}". -# You can safely remove "data/${lang}" and rerun this script to regenerate it. -mkdir -p data/${lang} - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if ! command -v ffmpeg &> /dev/null; then - echo "This dataset requires ffmpeg" - echo "Please install ffmpeg first" - echo "" - echo " sudo apt-get install ffmpeg" - exit 1 -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/$release, - # you can create a symlink - # - # ln -sfv /path/to/$release $dl_dir/$release - # - if [ ! -d $dl_dir/$release/$lang/clips ]; then - lhotse download commonvoice --languages $lang --release $release $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare CommonVoice manifest" - # We assume that you have downloaded the CommonVoice corpus - # to $dl_dir/$release - mkdir -p data/${lang}/manifests - if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then - lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests - - if [ $use_validated = true ] && [ ! -f data/${lang}/manifests/.cv-${lang}.validated.done ]; then - log "Also prepare validated data" - lhotse prepare commonvoice \ - --split validated \ - --language $lang \ - -j $nj $dl_dir/$release data/${lang}/manifests - touch data/${lang}/manifests/.cv-${lang}.validated.done - fi - - if [ $use_invalidated = true ] && [ ! 
-f data/${lang}/manifests/.cv-${lang}.invalidated.done ]; then - log "Also prepare invalidated data" - lhotse prepare commonvoice \ - --split invalidated \ - --language $lang \ - -j $nj $dl_dir/$release data/${lang}/manifests - touch data/${lang}/manifests/.cv-${lang}.invalidated.done - fi - - touch data/${lang}/manifests/.cv-${lang}.done - fi - - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. cp jq /usr/bin - if [ $use_validated = true ]; then - log "Getting cut ids from dev/test sets for later use" - gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_test.jsonl.gz \ - | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_test_ids - - gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_dev.jsonl.gz \ - | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_dev_ids - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Preprocess CommonVoice manifest" - if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then - ./local/preprocess_commonvoice.py --language $lang - touch data/${lang}/fbank/.preprocess_complete - fi - - if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.validated.preprocess_complete ]; then - log "Also preprocess validated data" - ./local/preprocess_commonvoice.py --language $lang --dataset validated - touch data/${lang}/fbank/.validated.preprocess_complete - fi - - if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.invalidated.preprocess_complete ]; then - log "Also preprocess invalidated data" - ./local/preprocess_commonvoice.py --language $lang --dataset invalidated - touch data/${lang}/fbank/.invalidated.preprocess_complete - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for dev and test subsets of CommonVoice" - mkdir -p data/${lang}/fbank - if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then - ./local/compute_fbank_commonvoice_dev_test.py --language $lang - touch data/${lang}/fbank/.cv-${lang}_dev_test.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split train subset into ${num_splits} pieces" - split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits} - if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then - lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir - touch $split_dir/.cv-${lang}_train_split.done - fi - - split_dir=data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} - if [ $use_validated = true ] && [ ! -f $split_dir/.cv-${lang}_validated.done ]; then - log "Also split validated data" - lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_validated_raw.jsonl.gz $split_dir - touch $split_dir/.cv-${lang}_validated.done - fi - - split_dir=data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} - if [ $use_invalidated = true ] && [ ! 
-f $split_dir/.cv-${lang}_invalidated.done ]; then - log "Also split invalidated data" - lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_invalidated_raw.jsonl.gz $split_dir - touch $split_dir/.cv-${lang}_invalidated.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compute features for train subset of CommonVoice" - if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then - ./local/compute_fbank_commonvoice_splits.py \ - --num-workers $nj \ - --batch-duration 200 \ - --start 0 \ - --num-splits $num_splits \ - --language $lang \ - --perturb-speed $perturb_speed - touch data/${lang}/fbank/.cv-${lang}_train.done - fi - - if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then - log "Also compute features for validated data" - ./local/compute_fbank_commonvoice_splits.py \ - --subset validated \ - --num-workers $nj \ - --batch-duration 200 \ - --start 0 \ - --num-splits $num_splits \ - --language $lang \ - --perturb-speed $perturb_speed - touch data/${lang}/fbank/.cv-${lang}_validated.done - fi - - if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then - log "Also compute features for invalidated data" - ./local/compute_fbank_commonvoice_splits.py \ - --subset invalidated \ - --num-workers $nj \ - --batch-duration 200 \ - --start 0 \ - --num-splits $num_splits \ - --language $lang \ - --perturb-speed $perturb_speed - touch data/${lang}/fbank/.cv-${lang}_invalidated.done - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Combine features for train" - if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then - pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") - lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz - fi - - if [ $use_validated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then - log "Also combine features for validated data" - pieces=$(find data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} -name "cv-${lang}_cuts_validated.*.jsonl.gz") - lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_validated.jsonl.gz - touch data/${lang}/fbank/.cv-${lang}_validated.done - fi - - if [ $use_invalidated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then - log "Also combine features for invalidated data" - pieces=$(find data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} -name "cv-${lang}_cuts_invalidated.*.jsonl.gz") - lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_invalidated.jsonl.gz - touch data/${lang}/fbank/.cv-${lang}_invalidated.done - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then - log "Stage 9: Prepare Char based lang" - lang_dir=data/${lang}/lang_char/ - mkdir -p $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for lang preparation" - - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. 
cp jq /usr/bin - if [ $use_validated = true ]; then - gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_validated.jsonl.gz \ - | jq '.text' | sed 's/"//g' >> $lang_dir/text - else - gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_train.jsonl.gz \ - | jq '.text' | sed 's/"//g' > $lang_dir/text - fi - - if [ $use_invalidated = true ]; then - gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_invalidated.jsonl.gz \ - | jq '.text' | sed 's/"//g' >> $lang_dir/text - fi - - if [ $lang == "yue" ] || [ $lang == "zh-HK" ]; then - # Get words.txt and words_no_ids.txt - ./local/word_segment_yue.py \ - --input-file $lang_dir/text \ - --output-dir $lang_dir \ - --lang $lang - - mv $lang_dir/text $lang_dir/_text - cp $lang_dir/transcript_words.txt $lang_dir/text - - if [ ! -f $lang_dir/tokens.txt ]; then - ./local/prepare_char.py --lang-dir $lang_dir - fi - else - log "word_segment_${lang}.py not implemented yet" - exit 1 - fi - fi - else - log "Stage 9: Prepare BPE based lang" - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - file=$( - find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz" - ) - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. cp jq /usr/bin - gunzip -c ${file} \ - | jq '.supervisions[].text' | sed 's/"//g' > $lang_dir/transcript_words.txt - - # Ensure space only appears once - sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/words.txt ]; then - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print "<eps> 0"; - } - { - if ($1 == "<s>") { - print "<s> is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "</s>") { - print "</s> is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf("<s> %d\n", NR+2); - printf("</s> %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ !
-f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done - fi -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare G" - # We assume you have install kaldilm, if not, please install - # it using: pip install kaldilm - - if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then - lang_dir=data/${lang}/lang_char - mkdir -p $lang_dir/lm - - for ngram in 3 ; do - if [ ! -f $lang_dir/lm/${ngram}-gram.unpruned.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_words.txt \ - -lm $lang_dir/lm/${ngram}gram.unpruned.arpa - fi - - if [ ! -f $lang_dir/lm/G_${ngram}_gram_char.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/lm/${ngram}gram.unpruned.arpa \ - > $lang_dir/lm/G_${ngram}_gram_char.fst.txt - fi - - if [ ! -f $lang_dir/lm/HLG.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_dir \ - --ngram-G $lang_dir/lm/G_${ngram}_gram_char.fst.txt - fi - done - else - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - mkdir -p $lang_dir/lm - #3-gram used in building HLG, 4-gram used for LM rescoring - for ngram in 3 4; do - if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_words.txt \ - -lm $lang_dir/lm/${ngram}gram.arpa - fi - - if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt - fi - done - done - fi -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Compile HLG" - - if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then - lang_dir=data/${lang}/lang_char - for ngram in 3 ; do - if [ ! -f $lang_dir/lm/HLG_${ngram}.fst ]; then - ./local/compile_hlg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char - fi - done - else - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done - fi -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Compile LG" - - if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then - lang_dir=data/${lang}/lang_char - for ngram in 3 ; do - if [ ! 
-f $lang_dir/lm/LG_${ngram}.fst ]; then - ./local/compile_lg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char - fi - done - else - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/${lang}/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done - fi -fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py deleted file mode 100644 index a80cfe85e..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1,440 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class CommonVoiceAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. CommonVoice test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
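The datamodule docstring above mentions dynamic batch sizes and bucketing samplers. Concretely, the sampler emits whole batches as CutSets whose total audio duration stays under `--max-duration`, and the DataLoader is created with `batch_size=None` because batching is already done by the sampler. A minimal usage sketch with placeholder paths:

```python
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from torch.utils.data import DataLoader

cuts = CutSet.from_file("data/en/fbank/cv-en_cuts_dev.jsonl.gz")  # placeholder path

sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,  # seconds of audio per batch, i.e. --max-duration
    num_buckets=30,
    shuffle=False,
)
dataset = K2SpeechRecognitionDataset(return_cuts=True)
# batch_size=None: the sampler already yields complete batches (CutSets).
dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=0)

batch = next(iter(dl))
feats = batch["inputs"]                # (num_cuts, num_frames, num_features)
texts = batch["supervisions"]["text"]  # transcripts aligned with the cuts
```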
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--language", - type=str, - default="en", - help="""Language of Common Voice""", - ) - group.add_argument( - "--cv-manifest-dir", - type=Path, - default=Path("data/en/fbank"), - help="Path to directory with CommonVoice train/dev/test cuts.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with the other cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" - ) - - @lru_cache() - def validated_cuts(self) -> CutSet: - logging.info("About to get validated cuts (with dev/test removed)") - return load_manifest_lazy( - self.args.cv_manifest_dir - / f"cv-{self.args.language}_cuts_validated.jsonl.gz" - ) - - @lru_cache() - def invalidated_cuts(self) -> CutSet: - logging.info("About to get invalidated cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir - / 
f"cv-{self.args.language}_cuts_invalidated.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" - ) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py deleted file mode 100755 index 52b2fbcab..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,962 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(8) modified beam search with RNNLM shallow fusion -./pruned_transducer_stateless5/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search_lm_shallow_fusion \ - --beam-size 4 \ - --lm-type rnn \ - --lm-scale 0.3 \ - --lm-exp-dir /path/to/LM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 - -(9) modified beam search with LM shallow fusion + LODR -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --max-duration 600 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --decoding-method modified_beam_search_LODR \ - --beam-size 4 \ - --lm-type rnn \ - --lm-scale 0.4 \ - --lm-exp-dir /path/to/LM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 - --tokens-ngram 2 \ - --ngram-lm-scale -0.16 \ - -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, - modified_beam_search_ngram_rescoring, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import LmScorer, NgramLm -from 
icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/en/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/en/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion - - modified_beam_search_LODR - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. 
- """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=3, - help="""Token Ngram used for rescoring. - Used only when the decoding method is - modified_beam_search_ngram_rescoring, or LODR - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="""ID of the backoff symbol. - Used only when the decoding method is - modified_beam_search_ngram_rescoring""", - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` - set to true. - ngram_lm: - A ngram lm. Used in LODR decoding. - ngram_lm_scale: - The scale of the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - 
encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif 
"beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if "ngram" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if params.use_shallow_fusion: - if params.lm_type == "rnn": - params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" - elif params.lm_type == "transformer": - params.suffix += f"-transformer-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load N-gram LM when needed - if "ngram" in params.decoding_method or "LODR" in params.decoding_method: - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - # only load the neural network LM if doing shallow fusion - if params.use_shallow_fusion: - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - - else: - LM = None - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - commonvoice = CommonVoiceAsrDataModule(args) - - dev_cuts = commonvoice.dev_cuts() - test_cuts = commonvoice.test_cuts() - - dev_dl = commonvoice.valid_dataloaders(dev_cuts) - test_dl = commonvoice.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py deleted file mode 100755 index 2b9f2293a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py +++ /dev/null @@ -1,601 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Yifan Yang) - -""" 
-This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-9999.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import sentencepiece as spm -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import Zipformer - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/en/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - encoder_out, encoder_out_lens = self.encoder(x, x_lens) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "stateless7", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 53705321e..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Yifan Yang) -# -# See 
../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/en/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 5 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/en/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 5 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/commonvoice/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/en/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 - # You will find the pre-trained model in icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/en/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, 
--avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py deleted file mode 100755 index f04537660..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. - -We use the pre-trained model from -https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-9999.pt -popd - -2. Export the model via torchscript (torch.jit.script()) - -./pruned_transducer_stateless7/export.py \ - --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp/ \ - --jit 1 - -It will generate the following file in $repo/exp: - - cpu_jit.pt - -3. Export the model to ONNX - -./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx - -4. 
Run this file - -./pruned_transducer_stateless7/onnx_check.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - return parser - - -def test_encoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - C = 80 - for i in range(3): - N = torch.randint(low=1, high=20, size=(1,)).item() - T = torch.randint(low=30, high=50, size=(1,)).item() - logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - - x = torch.rand(N, T, C) - x_lens = torch.randint(low=30, high=T + 1, size=(N,)) - x_lens[0] = T - - torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) - torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - - onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] - decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) - projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - - torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) 
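    # The rest of main() mirrors the three test_* helpers defined above: it
    # loads the torchscript export with torch.jit.load(), wraps the three
    # ONNX files in the OnnxModel helper imported from onnx_pretrained.py,
    # and then runs the encoder, decoder and joiner consistency checks in
    # turn. Each helper feeds identical random inputs to both backends and
    # asserts torch.allclose() within a small absolute tolerance, reporting
    # the maximum absolute difference in the assertion message when a check
    # fails.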
- - torch_model = torch.jit.load(args.jit_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - test_encoder(torch_model, onnx_model) - - logging.info("Test decoder") - test_decoder(torch_model, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_model, onnx_model) - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py deleted file mode 100755 index 52fed7331..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ /dev/null @@ -1,423 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-9999.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx - -3. 
Run this file - -./pruned_transducer_stateless7/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ - --tokens $repo/data/en/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). 
Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. 
- """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += symbol_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, 
level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100755 index b6e2451e8..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,356 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/en/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 5 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/en/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/en/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/en/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/en/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. 
- -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = 
"\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py deleted file mode 100755 index 5e98084ec..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1267 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/en/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--use-validated-set", - type=str2bool, - default=False, - help="""Use the validated set for training. - This is useful when you want to use more data for training, - but not recommended for research purposes. - """, - ) - - parser.add_argument( - "--use-invalidated-set", - type=str2bool, - default=False, - help="""Use the invalidated set for training. - In case you want to take the risk and utilize more data for training. - """, - ) - - parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
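    # Because of the sum reduction, the values accumulated in `info` are
    # totals over the batch; the average loss per frame used to track the
    # best epochs is obtained later by dividing the accumulated "loss" by
    # the accumulated "frames" (see compute_validation_loss and the end of
    # train_one_epoch).
    #
    # For reference, the warm-up scaling computed above behaves as follows
    # with the defaults --simple-loss-scale=0.5 and warm_step=2000: the
    # simple loss starts fully weighted and decays to its final scale,
    # while the pruned loss ramps up from 0.1 to 1.0.
    #
    #   batch_idx_train       0     1000   >=2000
    #   simple_loss_scale   1.00    0.75     0.50
    #   pruned_loss_scale   0.10    0.55     1.00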
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
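            # The update below follows the usual torch.cuda.amp recipe: the
            # scaled loss is backpropagated, the per-module batch counts and
            # the per-batch learning-rate schedule are advanced, and
            # scaler.step(optimizer) unscales the gradients and runs
            # optimizer.step() only when no inf/NaN gradients were produced;
            # scaler.update() then adapts the scale for the next iteration
            # before the gradients are cleared with optimizer.zero_grad().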
- scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - if args.use_validated_set: - train_cuts = commonvoice.validated_cuts() - else: - train_cuts = commonvoice.train_cuts() - - if args.use_invalidated_set: - train_cuts += commonvoice.invalidated_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = commonvoice.dev_cuts() - valid_dl = commonvoice.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md deleted file mode 100644 index 6c20bab2c..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md +++ /dev/null @@ -1,9 +0,0 @@ -This recipe implements Streaming Zipformer-Transducer model. - -See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. - -[./emformer.py](./emformer.py) and [./train.py](./train.py) -are basically the same as -[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py). -The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py) -is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py deleted file mode 120000 index c274de28a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/asr_datamodule.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py deleted file mode 120000 index d7349b0a3..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py deleted file mode 100755 index 7ae4f1894..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py +++ /dev/null @@ -1,811 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - feature_lens += 30 - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, 30), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += 
f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - - # errs_info = ( - # params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - # ) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> and <unk> are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found 
for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - commonvoice = CommonVoiceAsrDataModule(args) - - test_cuts = commonvoice.test_cuts() - - test_dl = commonvoice.test_dataloaders(test_cuts) - - test_sets = "test-cv" - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_sets, - results_dict=results_dict, - ) - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 120000 index ca8fed319..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py deleted file mode 120000 index 33944d0d2..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py deleted file mode 100755 index aefe88f3f..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ /dev/null @@ -1,1257 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
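The --use-averaged-model logic in the deleted decode.py above selects the two endpoint checkpoints of the epoch range (epoch - avg, epoch]: epoch-{epoch-avg}.pt is the excluded start point, epoch-{epoch}.pt the included end point, and only these two files are read by average_checkpoints_with_averaged_model(). A small worked sketch follows, with a hypothetical helper name; the filename pattern and the start >= 1 check come from the code above.

from pathlib import Path

def averaged_model_endpoints(exp_dir: Path, epoch: int, avg: int):
    assert avg > 0, avg
    start = epoch - avg
    assert start >= 1, start
    # The averaged model covers epochs (start, epoch]; only the two endpoint
    # checkpoints are loaded and combined with their stored running averages.
    return exp_dir / f"epoch-{start}.pt", exp_dir / f"epoch-{epoch}.pt"

# e.g. --epoch 28 --avg 15 (as in the usage examples above) selects
# epoch-13.pt (excluded) and epoch-28.pt (included):
print(averaged_model_endpoints(Path("pruned_transducer_stateless7_streaming/exp"), 28, 15))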
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. 
- valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError(f", exiting: {cur_grad_scale}") - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
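Because both losses are computed with reduction=sum, the raw values in loss_info and tot_loss above grow with batch size; they only become comparable once divided by the accumulated "frames" count, which is what the end-of-epoch train_loss and the validation loss do. A toy illustration with plain dicts standing in for MetricsTracker and made-up numbers:

# Two batches with summed losses and their frame counts after 4x subsampling.
batches = [
    {"loss": 1250.0, "frames": 4000},
    {"loss": 900.0, "frames": 2500},
]
tot = {"loss": 0.0, "frames": 0}
for b in batches:
    tot = {k: tot[k] + b[k] for k in tot}

per_frame_loss = tot["loss"] / tot["frames"]
print(f"per-frame loss: {per_frame_loss:.3f}")  # (1250 + 900) / (4000 + 2500) ≈ 0.331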
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - train_cuts = commonvoice.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = commonvoice.dev_cuts() - valid_dl = commonvoice.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py deleted file mode 120000 index cb673b3eb..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py deleted file mode 120000 index 72e43c297..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py deleted file mode 120000 index 3b36924ef..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py deleted file mode 120000 index 57a0cd0a0..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 120000 index 2acafdc61..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py deleted file mode 100755 index 976004eca..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ /dev/null @@ -1,1342 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--do-finetune", type=str2bool, default=False) - - parser.add_argument( - "--init-modules", - type=str, - 
default=None, - help=""" - Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma seperated. If None, - all modules will be initialised. For example, if you only want to - initialise all parameters staring with "encoder", use "encoder"; - if you want to initialise parameters starting with encoder or decoder, - use "encoder,joiner". - """, - ) - - parser.add_argument( - "--finetune-ckpt", - type=str, - default=None, - help="Fine-tuning from which checkpoint (a path to a .pt file)", - ) - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="""Embedding dimension in the 2 blocks of zipformer encoder - layers, comma separated - """, - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers,\ - comma separated; not the same as embedding dimension. - """, - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="""Unmasked dimensions in the encoders, relates to augmentation - during training. Must be <= each of encoder_dims. Empirically, less - than 256 seems to make performance worse. - """, - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. 
- If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="""Path to the BPE model. - This should be the bpe model of the original model - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.005, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=100000, - help="""Number of steps that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=100, - help="""Number of epochs that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. 
- For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - add_finetune_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. 
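# warm_step: the number of training batches over which compute_loss shifts weight
# from the simple (un-pruned) loss towards the pruned loss; see the scale
# computation there.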
- "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def load_model_params( - ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True -): - """Load model params from checkpoint - - Args: - ckpt (str): Path to the checkpoint - model (nn.Module): model to be loaded - - """ - logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") - - # if module list is empty, load the whole model from ckpt - if not init_modules: - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - else: - src_state_dict = checkpoint["model"] - dst_state_dict = model.state_dict() - for module in init_modules: - logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] - assert set(src_keys) == set(dst_keys) # two sets should match exactly - for key in src_keys: - dst_state_dict[key] = src_state_dict.pop(key) - - model.load_state_dict(dst_state_dict, strict=strict) - - return None - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
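The prefix matching that load_model_params above performs for --init-modules amounts to copying only those source state_dict entries whose keys start with one of the requested prefixes, leaving everything else at its fresh initialization. A reduced sketch with made-up parameter names:

src = {"encoder.layer0.weight": 1.0, "decoder.embed.weight": 2.0, "joiner.proj.weight": 3.0}
dst = {"encoder.layer0.weight": 0.0, "decoder.embed.weight": 0.0, "joiner.proj.weight": 0.0}

for prefix in ["encoder", "decoder"]:
    for key in [k for k in src if k.startswith(prefix)]:
        dst[key] = src[key]

print(dst)  # encoder/decoder weights come from the checkpoint; the joiner keeps 0.0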
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. 
- simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
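# The backward pass below runs on the scaled loss; scaler.step() unscales the
# gradients and skips the optimizer step if any of them contain inf/NaN, and
# scaler.update() then shrinks the scale after an overflow (or grows it again after
# enough clean steps).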
- scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have - # different behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
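The update rule quoted in the --average-period help text keeps model_avg equal to the mean of the parameters sampled once every average_period batches. A scalar sketch of that recurrence, with a single float standing in for a parameter tensor and invented sample values:

average_period = 200

def update_avg(avg: float, cur: float, batch_idx_train: int) -> float:
    # Same formula as in the help text, applied to one scalar parameter.
    w = average_period / batch_idx_train
    return cur * w + avg * (1.0 - w)

avg = 1.0                        # value sampled at batch 200
avg = update_avg(avg, 2.0, 400)  # -> 1.5,   the mean of (1.0, 2.0)
avg = update_avg(avg, 4.0, 600)  # -> 2.333, the mean of (1.0, 2.0, 4.0)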
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - # load model parameters for model fine-tuning - if params.do_finetune: - modules = params.init_modules.split(",") if params.init_modules else None - checkpoints = load_model_params( - ckpt=params.finetune_ckpt, model=model, init_modules=modules - ) - else: - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - train_cuts = commonvoice.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = commonvoice.dev_cuts() - valid_dl = commonvoice.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
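The T >= S filter in remove_short_and_long_utt above uses the same frame arithmetic as the convolutional front-end: roughly 4x subsampling with a few frames lost at the borders. A quick manual check for a single cut, assuming the 10 ms frame shift used in this recipe and hypothetical BPE pieces:

num_frames = 250                       # a 2.5 s utterance at 10 ms frame shift
tokens = ["▁he", "llo", "▁wor", "ld"]  # hypothetical BPE pieces for the supervision text

T = ((num_frames - 7) // 2 + 1) // 2   # frames left after the conv subsampling
assert T >= len(tokens), "this cut would be dropped: fewer frames than tokens"
print(T)  # 61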
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments( - parser - ) # you may replace this with your own dataset - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py deleted file mode 100755 index 3fd14aa47..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt -./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ - --epoch 28 \ - --avg 15 \ - --use-averaged-model True \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. - -(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt -./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ - --iter 22000 \ - --avg 5 \ - --use-averaged-model True \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. - -(3) use the original model with checkpoint exp_dir/epoch-xxx.pt -./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ - --epoch 28 \ - --avg 15 \ - --use-averaged-model False \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt -./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ - --iter 22000 \ - --avg 5 \ - --use-averaged-model False \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path -from typing import Dict, List - -import sentencepiece as spm -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model." - "If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - print("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - print(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - print(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = ( - params.exp_dir - / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt" - ) - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = ( - params.exp_dir - / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" - ) - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py deleted file mode 120000 index 5d9c6ba00..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py deleted file mode 120000 index 457131699..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py deleted file mode 120000 index 2b8fa3cbb..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py deleted file mode 120000 index ecfb6dd8a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py deleted file mode 120000 index e17d4f734..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py deleted file mode 120000 index 28bf7bb82..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py deleted file mode 120000 index c8548d459..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py deleted file mode 120000 index ae4d9bb04..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py deleted file mode 120000 index 81ac4a89a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 120000 index 9510b8fde..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py deleted file mode 120000 index 2428b74b9..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py deleted file mode 120000 index b8b8ba432..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py deleted file mode 120000 index 92c3904af..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py deleted file mode 120000 index 2adf271c1..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py deleted file mode 100755 index bb1c093c8..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ /dev/null @@ -1,615 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in 
compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
- tail_length = 23 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 50 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - idx = 0 - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
- initial_states = model.encoder.get_init_state(device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - audio: np.ndarray = cut.load_audio() - if audio.max() > 1 or audio.min() < -1: - audio = audio / max(abs(audio.max()), abs(audio.min())) - print(audio) - print(audio.max()) - print(audio.min()) - print(cut) - idx += 1 - print(idx) - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not 
enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - commonvoice = CommonVoiceAsrDataModule(args) - test_cuts = commonvoice.test_cuts() - test_sets = "test-cv" - - results_dict = decode_dataset( - cuts=test_cuts, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_sets, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py deleted file mode 100755 index 5400df804..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless7_streaming/test_model.py -""" - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - # Test jit script - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - print("Using torch.jit.script") - model = torch.jit.script(model) - - -def test_model_jit_trace(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - convert_scaled_to_non_scaled(model, inplace=True) - - # Test encoder - def _test_encoder(): - encoder = model.encoder - assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - encoder.decode_chunk_size, - params.decode_chunk_len, - ) - T = params.decode_chunk_len + 7 - - x = torch.zeros(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - states = encoder.get_init_state(device=x.device) - encoder.__class__.forward = encoder.__class__.streaming_forward - traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) - - states1 = encoder.get_init_state(device=x.device) - states2 = traced_encoder.get_init_state(device=x.device) - for i in range(5): - x = torch.randn(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) - y2, _, states2 = traced_encoder(x, x_lens, states2) - assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) - - # Test decoder - def _test_decoder(): - decoder = model.decoder - y = torch.zeros(10, decoder.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_decoder = torch.jit.trace(decoder, (y, need_pad)) - d1 = decoder(y, need_pad) - d2 = traced_decoder(y, need_pad) - assert 
torch.equal(d1, d2), (d1 - d2).abs().mean() - - # Test joiner - def _test_joiner(): - joiner = model.joiner - encoder_out_dim = joiner.encoder_proj.weight.shape[1] - decoder_out_dim = joiner.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) - j1 = joiner(encoder_out, decoder_out) - j2 = traced_joiner(encoder_out, decoder_out) - assert torch.equal(j1, j2), (j1 - j2).abs().mean() - - _test_encoder() - _test_decoder() - _test_joiner() - - -def main(): - test_model() - test_model_jit_trace() - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py deleted file mode 100755 index 67e1a8133..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ /dev/null @@ -1,1284 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/fr/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--use-validated-set", - type=str2bool, - default=False, - help="""Use the validated set for training. - This is useful when you want to use more data for training, - but not recommended for research purposes. - """, - ) - - parser.add_argument( - "--use-invalidated-set", - type=str2bool, - default=False, - help="""Use the invalidated set for training. - In case you want to take the risk and utilize more data for training. - """, - ) - - parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. 
We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. 
It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - 
model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. 
See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. 
- optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - if not args.use_validated_set: - train_cuts = commonvoice.train_cuts() - else: - train_cuts = commonvoice.validated_cuts() - - if args.use_invalidated_set: - train_cuts += commonvoice.invalidated_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = commonvoice.dev_cuts() - valid_dl = commonvoice.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 120000 index ec183baa7..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py deleted file mode 120000 index d301e1f9b..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/shared b/egs/commonvoice/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/commonvoice/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/asr_datamodule.py b/egs/commonvoice/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index c274de28a..000000000 --- a/egs/commonvoice/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/asr_datamodule.py \ No newline at end of 
file diff --git a/egs/commonvoice/ASR/zipformer/beam_search.py b/egs/commonvoice/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/commonvoice/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/decode.py b/egs/commonvoice/ASR/zipformer/decode.py deleted file mode 100755 index 7fd6d0ccd..000000000 --- a/egs/commonvoice/ASR/zipformer/decode.py +++ /dev/null @@ -1,1052 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - 
modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. 
- """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. 
- model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in 
sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. 
- - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
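# Illustrative note (hypothetical values): each entry in the `results` lists handled
# here is a three-element tuple (cut_id, ref_words, hyp_words), as built in
# decode_dataset() above, so a single greedy-search entry might look like
#     ("common_voice_en_0001", ["good", "morning"], ["good", "mourning"])
# where the cut id and words are made up for illustration; this is the layout that
# store_transcripts() and write_error_stats() consume in save_results().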
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append((sp.encode(line.strip()), 0.0)) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
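# Illustrative note: with --use-averaged-model (the default) and, say, --epoch 28
# --avg 15 as in the usage examples above, the averaging branch computes
# start = 28 - 15 = 13 and builds the model averaged over the epoch range (13, 28];
# only exp_dir/epoch-13.pt and exp_dir/epoch-28.pt are read, since train.py saves a
# running parameter average (model_avg) alongside each checkpoint and
# average_checkpoints_with_averaged_model() works from those two running averages.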
- args.return_cuts = True - commonvoice = CommonVoiceAsrDataModule(args) - - test_cuts = commonvoice.test_cuts() - dev_cuts = commonvoice.dev_cuts() - - test_dl = commonvoice.test_dataloaders(test_cuts) - dev_dl = commonvoice.valid_dataloaders(dev_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/decode_char.py b/egs/commonvoice/ASR/zipformer/decode_char.py deleted file mode 100755 index 1f8c9c7c6..000000000 --- a/egs/commonvoice/ASR/zipformer/decode_char.py +++ /dev/null @@ -1,813 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/zh-HK/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/zh-HK/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/zh-HK/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/zh-HK/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/zh-HK/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/zh-HK/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. 
The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, 
- ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
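# Illustrative note: with the parser defaults above (blank_penalty 0.0, beam 20.0,
# max_contexts 8, max_states 64), the result keys built in decode_one_batch() come
# out as "greedy_search_blank_penalty_0.0" for greedy search and
# "blank_penalty_0.0_beam_20.0_max_contexts_8_max_states_64" for fast_beam_search;
# these keys are later embedded in the recogs-*, errs-* and wer-summary-* filenames
# written by save_results().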
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
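# Illustrative note (hypothetical text): decode_dataset() above splits each reference
# into characters before scoring, e.g. for a made-up zh-HK supervision "你好 世界":
#     "".join("你好 世界".split())  ->  "你好世界"
#     list("你好世界")              ->  ['你', '好', '世', '界']
# so the "WER" reported by write_error_stats() is effectively a character error rate
# for this char-based recipe.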
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - commonvoice = CommonVoiceAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - dev_cuts = commonvoice.dev_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = commonvoice.valid_dataloaders(dev_cuts) - - test_cuts = commonvoice.test_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = commonvoice.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/decode_stream.py b/egs/commonvoice/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/commonvoice/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/decoder.py b/egs/commonvoice/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/commonvoice/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/encoder_interface.py b/egs/commonvoice/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/commonvoice/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py deleted file mode 120000 index f9d756352..000000000 --- a/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py deleted file mode 120000 index 652346001..000000000 --- a/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py +++ 
/dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export-onnx.py b/egs/commonvoice/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/commonvoice/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/export.py b/egs/commonvoice/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/commonvoice/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/joiner.py b/egs/commonvoice/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/commonvoice/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/model.py b/egs/commonvoice/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/commonvoice/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/onnx_check.py b/egs/commonvoice/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/commonvoice/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/onnx_pretrained.py b/egs/commonvoice/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/commonvoice/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/optim.py b/egs/commonvoice/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/commonvoice/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/scaling.py b/egs/commonvoice/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/commonvoice/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/scaling_converter.py b/egs/commonvoice/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/commonvoice/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/streaming_beam_search.py b/egs/commonvoice/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/commonvoice/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ 
-../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/streaming_decode.py b/egs/commonvoice/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 1d0230c76..000000000 --- a/egs/commonvoice/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,859 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import CommonVoiceAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. 
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
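Each cut below gets its own DecodeStream, and chunks are decoded whenever the pool holds at least --num-decode-streams streams, with a final drain at the end. A toy sketch of that pooling pattern, using hypothetical stand-ins rather than the real DecodeStream objects:

num_decode_streams = 3  # hypothetical, far smaller than the default of 2000
pool, results = [], []

def decode_one_chunk_toy(pool):
    # Pretend every call consumes one chunk per pooled stream and reports
    # the indexes of streams that have just finished.
    finished = []
    for i, stream in enumerate(pool):
        stream["chunks_left"] -= 1
        if stream["chunks_left"] == 0:
            finished.append(i)
    return finished

for utt_id in range(7):
    pool.append({"id": utt_id, "chunks_left": 2})
    while len(pool) >= num_decode_streams:
        for i in sorted(decode_one_chunk_toy(pool), reverse=True):
            results.append(pool[i]["id"])
            del pool[i]  # reverse order keeps the remaining indexes valid

while pool:  # drain the last streams, mirroring the final while-loop below
    for i in sorted(decode_one_chunk_toy(pool), reverse=True):
        results.append(pool[i]["id"])
        del pool[i]

print(sorted(results))  # [0, 1, 2, 3, 4, 5, 6] -- every utterance decoded exactly once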
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
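Before the detailed per-utterance error files are written below, here is a stripped-down sketch, with made-up WER values and an example output path, of the tab-separated summary this function ends with (best setting first):

# Made-up numbers purely for illustration.
test_set_wers = {"greedy_search": 12.3, "beam_4.0_max_contexts_4_max_states_32": 11.8}

ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
with open("wer-summary-example.txt", "w") as f:  # example path, not the recipe's naming
    print("settings\tWER", file=f)
    for key, val in ranked:
        print(f"{key}\t{val}", file=f)
# Resulting file:
# settings	WER
# beam_4.0_max_contexts_4_max_states_32	11.8
# greedy_search	12.3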
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - commonvoice = CommonVoiceAsrDataModule(args) - - test_cuts = commonvoice.test_cuts() - dev_cuts = commonvoice.dev_cuts() - - test_sets = ["test", "dev"] - test_cuts = [test_cuts, dev_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/streaming_decode_char.py b/egs/commonvoice/ASR/zipformer/streaming_decode_char.py deleted file mode 100755 index 249cba9f5..000000000 --- a/egs/commonvoice/ASR/zipformer/streaming_decode_char.py +++ /dev/null @@ -1,861 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
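The char-based variant that follows differs from the BPE script above mainly in how hypothesis token ids are turned back into text. A hedged side-by-side sketch; the ids and the token table below are hypothetical:

# Hypothetical char token table and hypothesis ids, for illustration only.
token_table = {0: "<blk>", 10: "你", 11: "好"}
hyp_ids = [10, 11]

# BPE script (above):  hyp = sp.decode(hyp_ids).split()          # words via SentencePiece
# Char script (below): hyp = [token_table[i] for i in hyp_ids]   # one token per character
print([token_table[i] for i in hyp_ids])  # ['你', '好']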
- -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -from asr_datamodule import CommonVoiceAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/zh-HK/lang_char", - help="Path to the lang dir(containing lexicon, tokens, etc.)", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. 
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
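For reference, the keys under which decode_dataset (above) groups its results look like this; the parameter values in the sketch are examples only:

decoding_method = "fast_beam_search"
beam, max_contexts, max_states, num_active_paths = 4.0, 4, 32, 4

if decoding_method == "greedy_search":
    key = "greedy_search"
elif decoding_method == "fast_beam_search":
    key = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
else:  # modified_beam_search
    key = f"num_active_paths_{num_active_paths}"

print(key)  # beam_4.0_max_contexts_4_max_states_32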
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - 
raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - commonvoice = CommonVoiceAsrDataModule(args) - - test_cuts = commonvoice.test_cuts() - dev_cuts = commonvoice.dev_cuts() - - test_sets = ["test", "dev"] - test_cuts = [test_cuts, dev_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/subsampling.py b/egs/commonvoice/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/commonvoice/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py deleted file mode 100755 index 271014db0..000000000 --- a/egs/commonvoice/ASR/zipformer/train.py +++ /dev/null @@ -1,1412 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/en/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--use-validated-set", - type=str2bool, - default=False, - help="""Use the validated set for training. - This is useful when you want to use more data for training, - but not recommended for research purposes. - """, - ) - - parser.add_argument( - "--use-invalidated-set", - type=str2bool, - default=False, - help="""Use the invalidated set for training. - In case you want to take the risk and utilize more data for training. - """, - ) - - parser.add_argument( - "--base-lr", - type=float, - default=0.045, - help="The base learning rate.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
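To make the subsampling comment above concrete, here is the frame bookkeeping for one illustrative utterance (the numbers are examples, and the final length is approximate):

T = 1000                          # 10 s of 80-dim fbank at a 10 ms frame shift
after_embed = (T - 7) // 2        # 496 frames, roughly 50 Hz, out of Conv2dSubsampling
after_encoder = after_embed // 2  # about 248 frames, roughly 25 Hz, after the
                                  # output_downsampling_factor=2 in Zipformer2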
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. 
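As a reading aid for the model constructors above: the comma-separated options are parsed by `_to_int_tuple` into one value per encoder stack, and the joiner uses the largest encoder dimension. With the defaults shown earlier this works out as follows (a restatement of those defaults, not a new configuration):

encoder_dim        = (192, 256, 384, 512, 384, 256)  # one value per stack
downsampling       = (1, 2, 4, 8, 4, 2)              # extra downsampling per stack
num_encoder_layers = (2, 2, 3, 4, 3, 2)
joiner_encoder_dim = max(encoder_dim)                # 512, as used in get_joiner_model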
- """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. 
- valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
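For quick reference, the schedule that the next few lines implement, with the thresholds taken from the code itself:

# every 100 batches:  if the grad scale is < 8.0, double it
# every 400 batches:  additionally double it while it is < 32.0
# scale < 0.01:       warn once and save a bad-model "first-warning" checkpoint
# scale < 1e-5:       save the model and abort via raise_grad_scale_is_too_small_error(),
#                     since losses are then effectively being zeroed out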
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - if not args.use_validated_set: - train_cuts = commonvoice.train_cuts() - else: - train_cuts = commonvoice.validated_cuts() - - if args.use_invalidated_set: - train_cuts += commonvoice.invalidated_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - dev_cuts = commonvoice.dev_cuts() - dev_dl = commonvoice.valid_dataloaders(dev_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=dev_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py deleted file mode 100755 index 0aa7856cc..000000000 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ /dev/null @@ -1,1052 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from typing import Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CommonVoiceAsrDataModule -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from train import ( - add_model_arguments, - get_adjusted_batch_count, - get_model, - load_checkpoint_if_available, - save_checkpoint, - set_batch_count, -) - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/zh-HK/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--use-validated-set", - type=str2bool, - default=False, - help="""Use the validated set for training. - This is useful when you want to use more data for training, - but not recommended for research purposes. - """, - ) - - parser.add_argument( - "--use-invalidated-set", - type=str2bool, - default=False, - help="""Use the invalidated set for training. - In case you want to take the risk and utilize more data for training. - """, - ) - - parser.add_argument( - "--base-lr", - type=float, - default=0.045, - help="The base learning rate.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. 
- """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. 
See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
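To make the warm-up interpolation above concrete, here is how the combined loss evolves with the recipe defaults (warm_step = 2000, simple_loss_scale = 0.5); the batch indices are illustrative:

# batch 0:            loss = 1.00 * simple_loss + 0.10 * pruned_loss
# batch 1000:         loss = 0.75 * simple_loss + 0.55 * pruned_loss
# batch >= warm_step: loss = 0.50 * simple_loss + 1.00 * pruned_loss
# plus params.ctc_loss_scale * ctc_loss whenever --use-ctc is enabled.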
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. 
- - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - commonvoice = CommonVoiceAsrDataModule(args) - - if not args.use_validated_set: - train_cuts = commonvoice.train_cuts() - else: - train_cuts = commonvoice.validated_cuts() - - if args.use_invalidated_set: - train_cuts += commonvoice.invalidated_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = commonvoice.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - dev_cuts = commonvoice.dev_cuts() - dev_dl = commonvoice.valid_dataloaders(dev_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=dev_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - CommonVoiceAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -# torch.set_num_threads(1) -# torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/zipformer/zipformer.py b/egs/commonvoice/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/commonvoice/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore deleted file mode 100644 index cd0e20c4c..000000000 --- a/egs/csj/ASR/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -librispeech_* 
-todelete* -lang* -notify_tg.py -finetune_* -misc.ini -.vscode/* -offline/* diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md deleted file mode 100644 index 95c2ec6ac..000000000 --- a/egs/csj/ASR/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Introduction - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -These are the types of architectures currently available. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|---------------------------------------------------| -| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming | diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md deleted file mode 100644 index 56fdb899f..000000000 --- a/egs/csj/ASR/RESULTS.md +++ /dev/null @@ -1,200 +0,0 @@ -# Results - -## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) - -### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) - -See for more details. - -You can find a pretrained model, training logs, decoding logs, and decoding results at: - - -Number of model parameters: 75688409, i.e. 75.7M. - -#### training on disfluent transcript - -The CERs are: - -| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | -| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | -| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming | -| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise | -| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming | -| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise | -| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming | -| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise | -| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming | -| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise | -| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming | -| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise | -| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming | -| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise | - -Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, -while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. 
- -The training command was: -```bash -./pruned_transducer_stateless7_streaming/train.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ - --max-duration 375 \ - --transcript-mode disfluent \ - --lang data/lang_char \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --pad-feature 30 \ - --musan-dir /mnt/host/corpus/musan/musan/fbank -``` - -The simulated streaming decoding command was: -```bash -for chunk in 64 32; do - for m in greedy_search fast_beam_search modified_beam_search; do - python pruned_transducer_stateless7_streaming/decode.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ - --epoch 30 \ - --avg 17 \ - --max-duration 350 \ - --decoding-method $m \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --lang data/lang_char \ - --transcript-mode disfluent \ - --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \ - --decode-chunk-len $chunk \ - --pad-feature 30 \ - --gpu 0 - done -done -``` - -The streaming chunk-wise decoding command was: -```bash -for chunk in 64 32; do - for m in greedy_search fast_beam_search modified_beam_search; do - python pruned_transducer_stateless7_streaming/streaming_decode.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ - --epoch 30 \ - --avg 17 \ - --max-duration 350 \ - --decoding-method $m \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --lang data/lang_char \ - --transcript-mode disfluent \ - --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \ - --decode-chunk-len $chunk \ - --gpu 2 \ - --num-decode-streams 40 - done -done -``` - -#### training on fluent transcript - -The CERs are: - -| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | -| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | -| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming | -| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise | -| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming | -| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise | -| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming | -| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise | -| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming | -| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise | -| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming | -| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise | -| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming | -| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise | - -Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, -while `chunk-size` indicates 
feeding certain number of frames at each time using `streaming_decode.py`. - -The training command was: -```bash -./pruned_transducer_stateless7_streaming/train.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ - --max-duration 375 \ - --transcript-mode fluent \ - --lang data/lang_char \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --pad-feature 30 \ - --musan-dir /mnt/host/corpus/musan/musan/fbank -``` - -The simulated streaming decoding command was: -```bash -for chunk in 64 32; do - for m in greedy_search fast_beam_search modified_beam_search; do - python pruned_transducer_stateless7_streaming/decode.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ - --epoch 30 \ - --avg 12 \ - --max-duration 350 \ - --decoding-method $m \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --lang data/lang_char \ - --transcript-mode fluent \ - --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \ - --decode-chunk-len $chunk \ - --pad-feature 30 \ - --gpu 1 - done -done -``` - -The streaming chunk-wise decoding command was: -```bash -for chunk in 64 32; do - for m in greedy_search fast_beam_search modified_beam_search; do - python pruned_transducer_stateless7_streaming/streaming_decode.py \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ - --epoch 30 \ - --avg 12 \ - --max-duration 350 \ - --decoding-method $m \ - --manifest-dir /mnt/host/corpus/csj/fbank \ - --lang data/lang_char \ - --transcript-mode fluent \ - --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \ - --decode-chunk-len $chunk \ - --gpu 3 \ - --num-decode-streams 40 - done -done -``` - -#### Comparing disfluent to fluent - -$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$ - -This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared. 
- -| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode | -| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- | -| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming | -| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise | -| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming | -| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise | -| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming | -| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise | -| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming | -| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise | -| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming | -| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise | -| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming | -| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise | -| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | | diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py deleted file mode 100644 index f6b4b2caf..000000000 --- a/egs/csj/ASR/local/add_transcript_mode.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse -import logging -from configparser import ConfigParser -from pathlib import Path -from typing import List - -from lhotse import CutSet, SupervisionSet -from lhotse.recipes.csj import CSJSDBParser - -ARGPARSE_DESCRIPTION = """ -This script adds transcript modes to an existing CutSet or SupervisionSet. -""" - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description=ARGPARSE_DESCRIPTION, - ) - parser.add_argument( - "-f", - "--fbank-dir", - type=Path, - help="Path to directory where manifests are stored.", - ) - parser.add_argument( - "-c", - "--config", - type=Path, - nargs="+", - help="Path to config file for transcript parsing.", - ) - return parser.parse_args() - - -def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]: - parsers = [] - for config_file in config_files: - config = ConfigParser() - config.optionxform = str - assert config.read(config_file), f"{config_file} could not be found." 
- decisions = {} - for k, v in config["DECISIONS"].items(): - try: - decisions[k] = int(v) - except ValueError: - decisions[k] = v - parsers.append( - (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions)) - ) - return parsers - - -def main(): - args = get_args() - logging.basicConfig( - format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), - level=logging.INFO, - ) - parsers = get_CSJParsers(args.config) - config = ConfigParser() - config.optionxform = str - assert config.read(args.config), args.config - decisions = {} - for k, v in config["DECISIONS"].items(): - try: - decisions[k] = int(v) - except ValueError: - decisions[k] = v - - logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.") - - manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz") - assert manifests, f"No cuts to be found in {args.fbank_dir}" - - for manifest in manifests: - results = [] - logging.info(f"Adding transcript modes to {manifest.name} now.") - cutset = CutSet.from_file(manifest) - for cut in cutset: - for name, parser in parsers: - cut.supervisions[0].custom[name] = parser.parse( - cut.supervisions[0].custom["raw"] - ) - cut.supervisions[0].text = "" - results.append(cut) - results = CutSet.from_items(results) - res_file = manifest.as_posix() - manifest.replace(manifest.parent / ("bak." + manifest.name)) - results.to_file(res_file) - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py deleted file mode 100644 index ce560025d..000000000 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import os -from pathlib import Path -from typing import List, Tuple - -import torch - -# fmt: off -from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - RecordingSet, - SupervisionSet, -) -from lhotse.recipes.csj import concat_csj_supervisions - -# fmt: on - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -RNG_SEED = 42 -# concat_params_train = [ -# {"gap": 1.0, "maxlen": 10.0}, -# {"gap": 1.5, "maxlen": 8.0}, -# {"gap": 1.0, "maxlen": 18.0}, -# ] - -concat_params = {"gap": 1.0, "maxlen": 10.0} - - -def make_cutset_blueprints( - manifest_dir: Path, -) -> List[Tuple[str, CutSet]]: - - cut_sets = [] - logging.info("Creating non-train cuts.") - - # Create eval datasets - for i in range(1, 4): - sps = sorted( - SupervisionSet.from_file( - manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" - ), - key=lambda x: x.id, - ) - - cut_set = CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" - ), - supervisions=concat_csj_supervisions(sps, **concat_params), - ) - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_sets.append((f"eval{i}", cut_set)) - - # Create excluded dataset - sps = sorted( - SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"), - key=lambda x: x.id, - ) - cut_set = CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "csj_recordings_excluded.jsonl.gz" - ), - supervisions=concat_csj_supervisions(sps, **concat_params), - ) - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_sets.append(("excluded", cut_set)) - - # Create valid dataset - sps = sorted( - SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"), - key=lambda x: x.id, - ) - cut_set = CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "csj_recordings_valid.jsonl.gz" - ), - supervisions=concat_csj_supervisions(sps, **concat_params), - ) - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_sets.append(("valid", cut_set)) - - logging.info("Creating train cuts.") - - # Create train dataset - sps = sorted( - SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz") - + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"), - key=lambda x: x.id, - ) - - recording = RecordingSet.from_file( - manifest_dir / "csj_recordings_core.jsonl.gz" - ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") - - train_set = CutSet.from_manifests( - recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params) - ).trim_to_supervisions(keep_overlapping=False) - train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - - cut_sets.append(("train", train_set)) - - return cut_sets - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-m", "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "-f", "--fbank-dir", type=Path, help="Path to save fbank features" - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - extractor = Fbank(FbankConfig(num_mel_bins=80)) - num_jobs = min(16, os.cpu_count()) - - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - if (args.fbank_dir / ".done").exists(): - logging.info( - "Previous fbank computed for CSJ found. " - f"Delete {args.fbank_dir / '.done'} to allow recomputing fbank." 
- ) - return - else: - cut_sets = make_cutset_blueprints(args.manifest_dir) - for part, cut_set in cut_sets: - logging.info(f"Processing {part}") - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - num_jobs=num_jobs, - storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz") - - logging.info("All fbank computed for CSJ.") - (args.fbank_dir / ".done").touch() - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py deleted file mode 100644 index c942df98e..000000000 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -ARGPARSE_DESCRIPTION = """ -This file computes fbank features of the musan dataset. - -""" - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=manifest_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = fbank_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{fbank_dir}/musan_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -def get_args(): - parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "-m", "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "-f", "--fbank-dir", type=Path, help="Path to save fbank features" - ) - - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini deleted file mode 100644 index 4f0a9ec0e..000000000 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ /dev/null @@ -1,79 +0,0 @@ -[CONSTANTS] -; # Name of this mode -MODE = disfluent - -[DECISIONS] -; # フィラー、感情表出系感動詞 -; # 0 to remain, 1 to delete -; # Example: '(F ぎょっ)' -F = 0 -; # 言い直し、いいよどみなどによる語断片 -; # 0 to remain, 1 to delete -; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = 0 -; # 助詞、助動詞、接辞の言い直し -; # 0 to remain, 1 to delete -; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = 0 -; # 聞き取りや語彙の判断に自信がない場合 -; # 0 to remain, 1 to delete -; # Example: (? 字数) の -; # If no option: empty string is returned regardless of output -; # Example: '(?) で' -? = 0 -; # タグ?で、値は複数の候補が想定される場合 -; # 0 for main guess with matching morph info, 1 for second guess -; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' -?, = 0 -; # 音や言葉に関するメタ的な引用 -; # 0 to remain, 1 to delete -; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' -M = 0 -; # 外国語や古語、方言など -; # 0 to remain, 1 to delete -; # Example: '(O ザッツファイン)' -O = 0 -; # 講演者の名前、差別語、誹謗中傷など -; # 0 to remain, 1 to delete -; # Example: '国語研の (R ××) です' -R = 0 -; # 非朗読対象発話(朗読における言い間違い等) -; # 0 to remain, 1 to delete -; # Example: '(X 実際は) 実際には' -X = 0 -; # アルファベットや算用数字、記号の表記 -; # 0 to use Japanese form, 1 to use alphabet form -; # Example: '(A シーディーアール;CD-R)' -A = 1 -; # タグAで、単語は算用数字の場合 -; # 0 to use Japanese form, 1 to use Arabic numerals -; # Example: (A 二千;2000) -A_num = 0 -; # 何らかの原因で漢字表記できなくなった場合 -; # 0 to use broken form, 1 to use orthodox form -; # Example: '(K たち (F えー) ばな;橘)' -K = 1 -; # 転訛、発音の怠けなど、一時的な発音エラー -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(W ギーツ;ギジュツ)' -W = 1 -; # 語の読みに関する知識レベルのいい間違い -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(B シブタイ;ジュータイ)' -B = 0 -; # 笑いながら発話 -; # 0 to remain, 1 to delete -; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' -笑 = 0 -; # 泣きながら発話 -; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' -泣 = 0 -; # 咳をしながら発話 -; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' -咳 = 0 -; # ささやき声や独り言などの小さな声 -; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? 
セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' -L = 0 diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini deleted file mode 100644 index 5d033ed17..000000000 --- a/egs/csj/ASR/local/conf/fluent.ini +++ /dev/null @@ -1,79 +0,0 @@ -[CONSTANTS] -; # Name of this mode -MODE = fluent - -[DECISIONS] -; # フィラー、感情表出系感動詞 -; # 0 to remain, 1 to delete -; # Example: '(F ぎょっ)' -F = 1 -; # 言い直し、いいよどみなどによる語断片 -; # 0 to remain, 1 to delete -; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = 1 -; # 助詞、助動詞、接辞の言い直し -; # 0 to remain, 1 to delete -; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = 1 -; # 聞き取りや語彙の判断に自信がない場合 -; # 0 to remain, 1 to delete -; # Example: (? 字数) の -; # If no option: empty string is returned regardless of output -; # Example: '(?) で' -? = 0 -; # タグ?で、値は複数の候補が想定される場合 -; # 0 for main guess with matching morph info, 1 for second guess -; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' -?, = 0 -; # 音や言葉に関するメタ的な引用 -; # 0 to remain, 1 to delete -; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' -M = 0 -; # 外国語や古語、方言など -; # 0 to remain, 1 to delete -; # Example: '(O ザッツファイン)' -O = 0 -; # 講演者の名前、差別語、誹謗中傷など -; # 0 to remain, 1 to delete -; # Example: '国語研の (R ××) です' -R = 0 -; # 非朗読対象発話(朗読における言い間違い等) -; # 0 to remain, 1 to delete -; # Example: '(X 実際は) 実際には' -X = 0 -; # アルファベットや算用数字、記号の表記 -; # 0 to use Japanese form, 1 to use alphabet form -; # Example: '(A シーディーアール;CD-R)' -A = 1 -; # タグAで、単語は算用数字の場合 -; # 0 to use Japanese form, 1 to use Arabic numerals -; # Example: (A 二千;2000) -A_num = 0 -; # 何らかの原因で漢字表記できなくなった場合 -; # 0 to use broken form, 1 to use orthodox form -; # Example: '(K たち (F えー) ばな;橘)' -K = 1 -; # 転訛、発音の怠けなど、一時的な発音エラー -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(W ギーツ;ギジュツ)' -W = 1 -; # 語の読みに関する知識レベルのいい間違い -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(B シブタイ;ジュータイ)' -B = 0 -; # 笑いながら発話 -; # 0 to remain, 1 to delete -; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' -笑 = 0 -; # 泣きながら発話 -; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' -泣 = 0 -; # 咳をしながら発話 -; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' -咳 = 0 -; # ささやき声や独り言などの小さな声 -; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' -L = 0 diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini deleted file mode 100644 index 3ada9aa24..000000000 --- a/egs/csj/ASR/local/conf/number.ini +++ /dev/null @@ -1,79 +0,0 @@ -[CONSTANTS] -; # Name of this mode -MODE = number - -[DECISIONS] -; # フィラー、感情表出系感動詞 -; # 0 to remain, 1 to delete -; # Example: '(F ぎょっ)' -F = 1 -; # 言い直し、いいよどみなどによる語断片 -; # 0 to remain, 1 to delete -; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = 1 -; # 助詞、助動詞、接辞の言い直し -; # 0 to remain, 1 to delete -; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = 1 -; # 聞き取りや語彙の判断に自信がない場合 -; # 0 to remain, 1 to delete -; # Example: (? 字数) の -; # If no option: empty string is returned regardless of output -; # Example: '(?) で' -? = 0 -; # タグ?で、値は複数の候補が想定される場合 -; # 0 for main guess with matching morph info, 1 for second guess -; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 
説明+し+た+方+が+いい+か+な)' -?, = 0 -; # 音や言葉に関するメタ的な引用 -; # 0 to remain, 1 to delete -; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' -M = 0 -; # 外国語や古語、方言など -; # 0 to remain, 1 to delete -; # Example: '(O ザッツファイン)' -O = 0 -; # 講演者の名前、差別語、誹謗中傷など -; # 0 to remain, 1 to delete -; # Example: '国語研の (R ××) です' -R = 0 -; # 非朗読対象発話(朗読における言い間違い等) -; # 0 to remain, 1 to delete -; # Example: '(X 実際は) 実際には' -X = 0 -; # アルファベットや算用数字、記号の表記 -; # 0 to use Japanese form, 1 to use alphabet form -; # Example: '(A シーディーアール;CD-R)' -A = 1 -; # タグAで、単語は算用数字の場合 -; # 0 to use Japanese form, 1 to use Arabic numerals -; # Example: (A 二千;2000) -A_num = 1 -; # 何らかの原因で漢字表記できなくなった場合 -; # 0 to use broken form, 1 to use orthodox form -; # Example: '(K たち (F えー) ばな;橘)' -K = 1 -; # 転訛、発音の怠けなど、一時的な発音エラー -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(W ギーツ;ギジュツ)' -W = 1 -; # 語の読みに関する知識レベルのいい間違い -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(B シブタイ;ジュータイ)' -B = 0 -; # 笑いながら発話 -; # 0 to remain, 1 to delete -; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' -笑 = 0 -; # 泣きながら発話 -; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' -泣 = 0 -; # 咳をしながら発話 -; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' -咳 = 0 -; # ささやき声や独り言などの小さな声 -; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' -L = 0 diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini deleted file mode 100644 index dafd65c9a..000000000 --- a/egs/csj/ASR/local/conf/symbol.ini +++ /dev/null @@ -1,80 +0,0 @@ -[CONSTANTS] -; # Name of this mode -; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf -MODE = symbol - -[DECISIONS] -; # フィラー、感情表出系感動詞 -; # 0 to remain, 1 to delete -; # Example: '(F ぎょっ)' -F = "#", ["F"] -; # 言い直し、いいよどみなどによる語断片 -; # 0 to remain, 1 to delete -; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = "@", ["D"] -; # 助詞、助動詞、接辞の言い直し -; # 0 to remain, 1 to delete -; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = "@", ["D2"] -; # 聞き取りや語彙の判断に自信がない場合 -; # 0 to remain, 1 to delete -; # Example: (? 字数) の -; # If no option: empty string is returned regardless of output -; # Example: '(?) で' -? = 0 -; # タグ?で、値は複数の候補が想定される場合 -; # 0 for main guess with matching morph info, 1 for second guess -; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 
説明+し+た+方+が+いい+か+な)' -?, = 0 -; # 音や言葉に関するメタ的な引用 -; # 0 to remain, 1 to delete -; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' -M = 0 -; # 外国語や古語、方言など -; # 0 to remain, 1 to delete -; # Example: '(O ザッツファイン)' -O = 0 -; # 講演者の名前、差別語、誹謗中傷など -; # 0 to remain, 1 to delete -; # Example: '国語研の (R ××) です' -R = 0 -; # 非朗読対象発話(朗読における言い間違い等) -; # 0 to remain, 1 to delete -; # Example: '(X 実際は) 実際には' -X = 0 -; # アルファベットや算用数字、記号の表記 -; # 0 to use Japanese form, 1 to use alphabet form -; # Example: '(A シーディーアール;CD-R)' -A = 1 -; # タグAで、単語は算用数字の場合 -; # 0 to use Japanese form, 1 to use Arabic numerals -; # Example: (A 二千;2000) -A_num = 1 -; # 何らかの原因で漢字表記できなくなった場合 -; # 0 to use broken form, 1 to use orthodox form -; # Example: '(K たち (F えー) ばな;橘)' -K = 1 -; # 転訛、発音の怠けなど、一時的な発音エラー -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(W ギーツ;ギジュツ)' -W = 1 -; # 語の読みに関する知識レベルのいい間違い -; # 0 to use wrong form, 1 to use orthodox form -; # Example: '(B シブタイ;ジュータイ)' -B = 0 -; # 笑いながら発話 -; # 0 to remain, 1 to delete -; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' -笑 = 0 -; # 泣きながら発話 -; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' -泣 = 0 -; # 咳をしながら発話 -; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' -咳 = 0 -; # ささやき声や独り言などの小さな声 -; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' -L = 0 diff --git a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py deleted file mode 100644 index 45c9c7656..000000000 --- a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py +++ /dev/null @@ -1,202 +0,0 @@ -import argparse -from pathlib import Path - -import kaldialign -from lhotse import CutSet - -ARGPARSE_DESCRIPTION = """ -This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript, -compares it against a fluent transcript, and saves the results in a separate directory. -This is useful to compare disfluent models with fluent models on the same metric. - -""" - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description=ARGPARSE_DESCRIPTION, - ) - parser.add_argument( - "--recogs", - type=Path, - required=True, - help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.", - ) - parser.add_argument( - "--cut", - type=Path, - required=True, - help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.", - ) - parser.add_argument( - "--res-dir", type=Path, required=True, help="Path to save results" - ) - return parser.parse_args() - - -def d2f(stats): - """ - Compare the outputs of a disfluent model against a fluent reference. - Indicates a disfluent model's performance only on the content words - - CER^d_f = (sub_f + ins + del_f) / Nf - - """ - return stats["base"] / stats["Nf"] - - -def calc_cer(refs, hyps): - subs = { - "F": 0, - "D": 0, - } - ins = 0 - dels = { - "F": 0, - "D": 0, - } - cors = { - "F": 0, - "D": 0, - } - dis_ref_len = 0 - flu_ref_len = 0 - - for ref, hyp in zip(refs, hyps): - assert ( - ref[0] == hyp[0] - ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}." - tag = ref[2].copy() - ref = ref[1] - dis_ref_len += len(ref) - # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively. 
- flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)]) - hyp = hyp[1] - ali = kaldialign.align(ref, hyp, "*") - tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali] - for tag, (ref_word, hyp_word) in zip(tags, ali): - if "D" in tag or "F" in tag: - tag = "D" - else: - tag = "F" - - if ref_word == "*": - ins += 1 - elif hyp_word == "*": - dels[tag] += 1 - elif ref_word != hyp_word: - subs[tag] += 1 - else: - cors[tag] += 1 - - return { - "subs": subs, - "ins": ins, - "dels": dels, - "cors": cors, - "dis_ref_len": dis_ref_len, - "flu_ref_len": flu_ref_len, - } - - -def for_each_recogs(recogs_file: Path, refs, out_dir): - hyps = [] - with recogs_file.open() as fin: - for line in fin: - if "ref" in line: - continue - cutid, hyp = line.split(":\thyp=") - hyps.append((cutid, eval(hyp))) - - assert len(refs) == len( - hyps - ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal." - stats = calc_cer(refs, hyps) - stat_table = ["tag,yes,no"] - - for cer_type in ["subs", "dels", "cors", "ins"]: - ret = f"{cer_type}" - for df in ["D", "F"]: - try: - ret += f",{stats[cer_type][df]}" - except TypeError: - # insertions do not belong to F or D, and is not subscriptable. - ret += f",{stats[cer_type]}," - break - stat_table.append(ret) - stat_table = "\n".join(stat_table) - - stats = { - "subd": stats["subs"]["D"], - "deld": stats["dels"]["D"], - "cord": stats["cors"]["D"], - "Nf": stats["flu_ref_len"], - "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"], - } - - cer = d2f(stats) - results = [ - f"{cer:.2%}", - f"Nf,{stats['Nf']}", - ] - results = "\n".join(results) - - with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout: - fout.write(results) - fout.write("\n\n") - fout.write(stat_table) - - -def main(): - args = get_args() - recogs_file: Path = args.recogs - assert ( - recogs_file.is_file() or recogs_file.is_dir() - ), f"recogs_file cannot be found at {recogs_file}." - - args.res_dir.mkdir(parents=True, exist_ok=True) - - if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"): - assert ( - "csj_cuts" in args.cut.name - ), f"Expected {args.cut} to be a cuts manifest." - - refs: CutSet = CutSet.from_file(args.cut) - refs = sorted( - [ - ( - e.id, - list(e.supervisions[0].custom["disfluent"]), - e.supervisions[0].custom["disfluent_tag"].split(","), - ) - for e in refs - ], - key=lambda x: x[0], - ) - for_each_recogs(recogs_file, refs, args.res_dir) - - elif recogs_file.is_dir(): - recogs_file_path = recogs_file - for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]: - refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz") - refs = sorted( - [ - ( - r.id, - list(r.supervisions[0].custom["disfluent"]), - r.supervisions[0].custom["disfluent_tag"].split(","), - ) - for r in refs - ], - key=lambda x: x[0], - ) - for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"): - for_each_recogs(recogs_file, refs, args.res_dir) - - else: - raise TypeError(f"Unrecognised recogs file provided: {recogs_file}") - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py deleted file mode 100644 index 924474d33..000000000 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,328 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -from pathlib import Path - -from lhotse import CutSet, load_manifest - -ARGPARSE_DESCRIPTION = """ -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in -pruned_transducer_stateless5/train.py for usage. -""" - - -def get_parser(): - parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") - - return parser.parse_args() - - -def main(): - args = get_parser() - - for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]: - path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz" - cuts: CutSet = load_manifest(path) - - print("\n---------------------------------\n") - print(path.name + ":") - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -csj_cuts_eval1.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1023 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:55:40 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.8 │ -├───────────────────────────┼──────────┤ -│ std │ 2.7 │ -├───────────────────────────┼──────────┤ -│ min │ 0.2 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 7.7 │ -├───────────────────────────┼──────────┤ -│ 75% │ 9.0 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 10.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1023 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1023 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:55:40 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - ---------------------------------- - -csj_cuts_eval2.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1025 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) 
│ 02:02:07 │ -├───────────────────────────┼──────────┤ -│ mean │ 7.1 │ -├───────────────────────────┼──────────┤ -│ std │ 2.5 │ -├───────────────────────────┼──────────┤ -│ min │ 0.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 5.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 7.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 9.1 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 10.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1025 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1025 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:02:07 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - ---------------------------------- - -csj_cuts_eval3.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 865 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:26:44 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.0 │ -├───────────────────────────┼──────────┤ -│ std │ 3.0 │ -├───────────────────────────┼──────────┤ -│ min │ 0.3 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.3 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.8 │ -├───────────────────────────┼──────────┤ -│ 75% │ 8.7 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 10.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 865 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 865 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:26:44 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - ---------------------------------- - -csj_cuts_valid.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 3743 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 06:40:15 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.4 │ -├───────────────────────────┼──────────┤ -│ std │ 3.0 │ -├───────────────────────────┼──────────┤ -│ min │ 0.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 7.4 │ -├───────────────────────────┼──────────┤ -│ 75% │ 9.0 │ 
-├───────────────────────────┼──────────┤ -│ 99% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.1 │ -├───────────────────────────┼──────────┤ -│ max │ 11.8 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 3743 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 3743 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 06:40:15 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - ---------------------------------- - -csj_cuts_excluded.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 980 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 00:56:06 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.4 │ -├───────────────────────────┼──────────┤ -│ std │ 3.1 │ -├───────────────────────────┼──────────┤ -│ min │ 0.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 0.8 │ -├───────────────────────────┼──────────┤ -│ 50% │ 2.2 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 9.9 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.9 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 10.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 980 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 980 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 00:56:06 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - ---------------------------------- - -csj_cuts_train.jsonl.gz: -Cut statistics: -╒═══════════════════════════╤════════════╕ -│ Cuts count: │ 914151 │ -├───────────────────────────┼────────────┤ -│ Total duration (hh:mm:ss) │ 1695:29:43 │ -├───────────────────────────┼────────────┤ -│ mean │ 6.7 │ -├───────────────────────────┼────────────┤ -│ std │ 2.9 │ -├───────────────────────────┼────────────┤ -│ min │ 0.1 │ -├───────────────────────────┼────────────┤ -│ 25% │ 4.6 │ -├───────────────────────────┼────────────┤ -│ 50% │ 7.5 │ -├───────────────────────────┼────────────┤ -│ 75% │ 8.9 │ -├───────────────────────────┼────────────┤ -│ 99% │ 11.0 │ -├───────────────────────────┼────────────┤ -│ 99.5% │ 11.0 │ -├───────────────────────────┼────────────┤ -│ 99.9% │ 11.1 │ -├───────────────────────────┼────────────┤ -│ max │ 18.0 │ -├───────────────────────────┼────────────┤ -│ Recordings available: │ 914151 │ 
-├───────────────────────────┼────────────┤ -│ Features available: │ 0 │ -├───────────────────────────┼────────────┤ -│ Supervisions available: │ 914151 │ -╘═══════════════════════════╧════════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤════════════╤══════════════════════╕ -│ Total speech duration │ 1695:29:43 │ 100.00% of recording │ -├──────────────────────────────┼────────────┼──────────────────────┤ -│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │ -├──────────────────────────────┼────────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧════════════╧══════════════════════╛ -""" diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py deleted file mode 100644 index 58b197922..000000000 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet -from lhotse.recipes.csj import CSJSDBParser - -ARGPARSE_DESCRIPTION = """ -This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system. - -It outputs 3 files into the lang directory: -- tokens.txt: a list of tokens in the output set. -- lang_type: a file that contains the string "char" - -""" - - -def get_args(): - parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help=( - "Name of lang dir. 
" - "If not set, this will default to lang_char_{trans-mode}" - ), - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.basicConfig( - format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), - level=logging.INFO, - ) - - sysdef_string = set(["", "", ""]) - - # Using disfluent parsing as fluent is a subset of disfluent - parser = CSJSDBParser() - - token_set = set() - logging.info(f"Creating vocabulary from {args.train_cut}.") - train_cut: CutSet = CutSet.from_file(args.train_cut) - for cut in train_cut: - if "_sp" in cut.id: - continue - - text: str = cut.supervisions[0].custom["raw"] - for w in parser.parse(text, sep=" ").split(" "): - token_set.update(w) - - token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] - args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "tokens.txt").write_text( - "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) - ) - - (args.lang_dir / "lang_type").write_text("char") - logging.info("Done.") - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py deleted file mode 100644 index 7bf7bdef0..000000000 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ /dev/null @@ -1,464 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset): - def __init__( - self, - *args, - transcript_mode: str = "", - return_cuts: bool = False, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.transcript_mode = transcript_mode - self.return_cuts = True - self._return_cuts = return_cuts - - def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: - batch = super().__getitem__(cuts) - - if self.transcript_mode: - batch["supervisions"]["text"] = [ - supervision.custom[self.transcript_mode] - for cut in batch["supervisions"]["cut"] - for supervision in cut.supervisions - ] - - if not self._return_cuts: - del batch["supervisions"]["cut"] - - return batch - - -class CSJAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--transcript-mode", - type=str, - default="", - help="Mode of transcript in supervision to use.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--musan-dir", type=Path, help="Path to directory with musan cuts. " - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. 
mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = AsrVariableTranscriptDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - transcript_mode=self.args.transcript_mode, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = AsrVariableTranscriptDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - transcript_mode=self.args.transcript_mode, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
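For clarity: the num_frame_masks logic above probes the default value of a constructor argument so the recipe behaves consistently across Lhotse versions. A self-contained sketch of that pattern, using a stand-in function rather than lhotse's SpecAugment (the names here are illustrative only):

import inspect

def make_masker(num_frame_masks: int = 1, mask_size: int = 100):
    # Stand-in constructor; only its signature matters for this illustration.
    return {"num_frame_masks": num_frame_masks, "mask_size": mask_size}

# Probe the default of one argument, as the datamodule does for SpecAugment.
default = inspect.signature(make_masker).parameters["num_frame_masks"].default
num_frame_masks = 2 if default == 1 else 10
print(num_frame_masks)  # -> 2 with this stand-in signature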
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = AsrVariableTranscriptDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - transcript_mode=self.args.transcript_mode, - ) - else: - validate = AsrVariableTranscriptDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - transcript_mode=self.args.transcript_mode, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - - test = AsrVariableTranscriptDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - transcript_mode=self.args.transcript_mode, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz") - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get valid cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz") - - @lru_cache() - def excluded_cuts(self) -> CutSet: - logging.info("About to get excluded cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz") - - @lru_cache() - def eval1_cuts(self) -> CutSet: - logging.info("About to get eval1 cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz") - - @lru_cache() - def eval2_cuts(self) -> CutSet: - logging.info("About to get eval2 cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz") - - @lru_cache() - def eval3_cuts(self) -> CutSet: - logging.info("About to get eval3 cuts") - return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz") diff --git a/egs/csj/ASR/local/utils/tokenizer.py b/egs/csj/ASR/local/utils/tokenizer.py deleted file mode 100644 index c9be72be1..000000000 --- a/egs/csj/ASR/local/utils/tokenizer.py +++ /dev/null @@ -1,253 +0,0 @@ -import argparse -from pathlib import Path -from typing import Callable, List, Union - -import sentencepiece as spm -from k2 import SymbolTable - - -class Tokenizer: - text2word: Callable[[str], List[str]] - - @staticmethod - def add_arguments(parser: argparse.ArgumentParser): - group = 
parser.add_argument_group(title="Lang related options") - - group.add_argument("--lang", type=Path, help="Path to lang directory.") - - group.add_argument( - "--lang-type", - type=str, - default=None, - help=( - "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " - "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" - ), - ) - - @staticmethod - def Load(lang_dir: Path, lang_type="", oov=""): - - if not lang_type: - assert (lang_dir / "lang_type").exists(), "lang_type not specified." - lang_type = (lang_dir / "lang_type").read_text().strip() - - tokenizer = None - - if lang_type == "bpe": - assert ( - lang_dir / "bpe.model" - ).exists(), f"No BPE .model could be found in {lang_dir}." - tokenizer = spm.SentencePieceProcessor() - tokenizer.Load(str(lang_dir / "bpe.model")) - elif lang_type == "char": - tokenizer = CharTokenizer(lang_dir, oov=oov) - else: - raise NotImplementedError(f"{lang_type} not supported at the moment.") - - return tokenizer - - load = Load - - def PieceToId(self, piece: str) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - piece_to_id = PieceToId - - def IdToPiece(self, id: int) -> str: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - id_to_piece = IdToPiece - - def GetPieceSize(self) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - get_piece_size = GetPieceSize - - def __len__(self) -> int: - return self.get_piece_size() - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def EncodeAsIds(self, input: str) -> List[int]: - return self.EncodeAsIdsBatch([input])[0] - - def EncodeAsPieces(self, input: str) -> List[str]: - return self.EncodeAsPiecesBatch([input])[0] - - def Encode( - self, input: Union[str, List[str]], out_type=int - ) -> Union[List, List[List]]: - if not input: - return [] - - if isinstance(input, list): - if out_type is int: - return self.EncodeAsIdsBatch(input) - if out_type is str: - return self.EncodeAsPiecesBatch(input) - - if out_type is int: - return self.EncodeAsIds(input) - if out_type is str: - return self.EncodeAsPieces(input) - - encode = Encode - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodeIds(self, input: List[int]) -> str: - return self.DecodeIdsBatch([input])[0] - - def DecodePieces(self, input: List[str]) -> str: - return self.DecodePiecesBatch([input])[0] - - def Decode( - self, - input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], - ) -> Union[List[str], str]: - - if not input: - return "" - - if isinstance(input, int): - return self.id_to_piece(input) - elif isinstance(input, str): - raise TypeError( - "Unlike spm.SentencePieceProcessor, cannot decode from type str." 
- ) - - if isinstance(input[0], list): - if not input[0] or isinstance(input[0][0], int): - return self.DecodeIdsBatch(input) - - if isinstance(input[0][0], str): - return self.DecodePiecesBatch(input) - - if isinstance(input[0], int): - return self.DecodeIds(input) - if isinstance(input[0], str): - return self.DecodePieces(input) - - raise RuntimeError("Unknown input type") - - decode = Decode - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: - if isinstance(input, list): - return self.SplitBatch(input) - elif isinstance(input, str): - return self.SplitBatch([input])[0] - raise RuntimeError("Unknown input type") - - split = Split - - -class CharTokenizer(Tokenizer): - def __init__(self, lang_dir: Path, oov="<unk>", sep=""): - assert ( - lang_dir / "tokens.txt" - ).exists(), f"tokens.txt could not be found in {lang_dir}." - token_table = SymbolTable.from_file(lang_dir / "tokens.txt") - assert ( - "#0" not in token_table - ), "This tokenizer does not support disambig symbols." - self._id2sym = token_table._id2sym - self._sym2id = token_table._sym2id - self.oov = oov - self.oov_id = self._sym2id[oov] - self.sep = sep - if self.sep: - self.text2word = lambda x: x.split(self.sep) - else: - self.text2word = lambda x: list(x.replace(" ", "")) - - def piece_to_id(self, piece: str) -> int: - try: - return self._sym2id[piece] - except KeyError: - return self.oov_id - - def id_to_piece(self, id: int) -> str: - return self._id2sym[id] - - def get_piece_size(self) -> int: - return len(self._sym2id) - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - return [ - [i if i in self._sym2id else self.oov for i in self.text2word(text)] - for text in input - ] - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - return [self.sep.join(text) for text in input] - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - return [self.text2word(text) for text in input] - - -def test_CharTokenizer(): - test_single_string = "こんにちは" - test_multiple_string = [ - "今日はいい天気ですよね", - "諏訪湖は綺麗でしょう", - "这在词表外", - "分かち 書き に し た 文章 です", - "", - ] - test_empty_string = "" - sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>") - splitter = sp.split - print(sp.encode(test_single_string, out_type=str)) - print(sp.encode(test_single_string, out_type=int)) - print(sp.encode(test_multiple_string, out_type=str)) - print(sp.encode(test_multiple_string, out_type=int)) - print(sp.encode(test_empty_string, out_type=str)) - print(sp.encode(test_empty_string, out_type=int)) - print(sp.decode(sp.encode(test_single_string, out_type=str))) - print(sp.decode(sp.encode(test_single_string, out_type=int))) - print(sp.decode(sp.encode(test_multiple_string, out_type=str))) - print(sp.decode(sp.encode(test_multiple_string, out_type=int))) - print(sp.decode(sp.encode(test_empty_string, out_type=str))) - print(sp.decode(sp.encode(test_empty_string, out_type=int))) - print(splitter(test_single_string)) - print(splitter(test_multiple_string)) - print(splitter(test_empty_string)) - - -if __name__ == "__main__": - test_CharTokenizer()
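For a concrete picture of the encode/decode contract that CharTokenizer implements, here is a minimal sketch using plain dicts in place of k2.SymbolTable. The toy vocabulary, including the <blk>/<unk>/<sos/eos> specials, is illustrative only; the real inventory comes from the tokens.txt written by prepare_lang_char.py.

# Toy vocabulary mirroring the tokens.txt layout (blank first, specials last).
sym2id = {"<blk>": 0, "こ": 1, "ん": 2, "に": 3, "ち": 4, "は": 5, "<unk>": 6, "<sos/eos>": 7}
id2sym = {i: s for s, i in sym2id.items()}
oov_id = sym2id["<unk>"]

def encode(text: str) -> list:
    # Unknown characters fall back to the OOV id, as in CharTokenizer.piece_to_id.
    return [sym2id.get(ch, oov_id) for ch in text.replace(" ", "")]

def decode(ids: list) -> str:
    # The char tokenizer uses an empty separator, so pieces are joined directly.
    return "".join(id2sym[i] for i in ids)

assert decode(encode("こんにちは")) == "こんにちは"
assert encode("犬") == [oov_id]  # an out-of-vocabulary character maps to <unk>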
diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py deleted file mode 100644 index 7f67c64b6..000000000 --- a/egs/csj/ASR/local/validate_manifest.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - s = c.supervisions[0] - - # Removed because when the cuts were trimmed from supervisions, - # the start time of the supervision can be lesser than cut start time. - # https://github.com/lhotse-speech/lhotse/issues/813 - # if s.start < c.start: - # raise ValueError( - # f"{c.id}: Supervision start time {s.start} is less " - # f"than cut start time {c.start}" - # ) - - if s.end > c.end: - raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" - ) - - -def main(): - args = get_args() - - manifest = Path(args.manifest) - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest(manifest) - assert isinstance(cut_set, CutSet) - - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh deleted file mode 100755 index 52339bb35..000000000 --- a/egs/csj/ASR/prepare.sh +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env bash -# We assume the following directories are downloaded. -# -# - $csj_dir -# CSJ is assumed to be the USB-type directory, which should contain the following subdirectories:- -# - DATA (not used in this script) -# - DOC (not used in this script) -# - MODEL (not used in this script) -# - MORPH -# - LDB (not used in this script) -# - SUWDIC (not used in this script) -# - SDB -# - core -# - ... -# - noncore -# - ... 
-# - PLABEL (not used in this script) -# - SUMMARY (not used in this script) -# - TOOL (not used in this script) -# - WAV -# - core -# - ... -# - noncore -# - ... -# - XML (not used in this script) -# -# - $musan_dir -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# - music -# - noise -# - speech -# -# By default, this script produces the original transcript like kaldi and espnet. Optionally, you -# can add other transcript formats by supplying your own config files. A few examples of these -# config files can be found in local/conf. - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=8 -stage=-1 -stop_stage=100 - -csj_dir=/mnt/host/corpus/csj -musan_dir=/mnt/host/corpus/musan/musan -trans_dir=$csj_dir/transcript -csj_fbank_dir=/mnt/host/corpus/csj/fbank -musan_fbank_dir=$musan_dir/fbank -csj_manifest_dir=data/manifests -musan_manifest_dir=$musan_dir/manifests - -. shared/parse_options.sh || exit 1 - -mkdir -p data - -log() { - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare CSJ manifest" - if [ ! -e $csj_manifest_dir/.csj.done ]; then - lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16 - touch $csj_manifest_dir/.csj.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - mkdir -p $musan_manifest_dir - if [ ! -e $musan_manifest_dir/.musan.done ]; then - lhotse prepare musan $musan_dir $musan_manifest_dir - touch $musan_manifest_dir/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute CSJ fbank" - if [ ! -e $csj_fbank_dir/.csj-validated.done ]; then - python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ - --fbank-dir $csj_fbank_dir - parts=( - eval1 - eval2 - eval3 - valid - excluded - train - ) - for part in ${parts[@]}; do - python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz - done - touch $csj_fbank_dir/.csj-validated.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare CSJ lang_char" - python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz - python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for musan" - mkdir -p $musan_fbank_dir - - if [ ! 
-e $musan_fbank_dir/.musan.done ]; then - python local/compute_fbank_musan.py --manifest-dir $musan_manifest_dir --fbank-dir $musan_fbank_dir - touch $musan_fbank_dir/.musan.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt - cat $csj_fbank_dir/manifest_statistics.txt -fi diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py deleted file mode 100644 index f5235207a..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -from configparser import ConfigParser - -import requests - - -def escape_html(text: str): - """ - Escapes all html characters in text - :param str text: - :rtype: str - """ - return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") - - -class TelegramStreamIO(logging.Handler): - - API_ENDPOINT = "https://api.telegram.org" - MAX_MESSAGE_LEN = 4096 - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s at %(funcName)s " - "(line %(lineno)s):\n\n%(message)s" - ) - - def __init__(self, tg_configfile: str): - super(TelegramStreamIO, self).__init__() - config = ConfigParser() - if not config.read(tg_configfile): - raise FileNotFoundError( - f"{tg_configfile} not found. " "Retry without --telegram-cred flag." - ) - config = config["TELEGRAM"] - token = config["token"] - self.chat_id = config["chat_id"] - self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage" - - @staticmethod - def setup_logger(params): - if not params.telegram_cred: - return - formatter = logging.Formatter( - f"{params.exp_dir.name} %(asctime)s \n%(message)s" - ) - tg = TelegramStreamIO(params.telegram_cred) - tg.setLevel(logging.WARN) - tg.setFormatter(formatter) - logging.getLogger("").addHandler(tg) - - def emit(self, record: logging.LogRecord): - """ - Emit a record. - Send the record to the Web server as a percent-encoded dictionary - """ - data = { - "chat_id": self.chat_id, - "text": self.format(self.mapLogRecord(record)), - "parse_mode": "HTML", - } - try: - requests.get(self.url, json=data) - # return response.json() - except Exception as e: - logging.error(f"Failed to send telegram message: {repr(e)}") - pass - - def mapLogRecord(self, record): - """ - Default implementation of mapping the log record into a dict - that is sent as the CGI data. Overwrite in your class. - Contributed by Franz Glasner.
- """ - - for k, v in record.__dict__.items(): - if isinstance(v, str): - setattr(record, k, escape_html(v)) - return record diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py deleted file mode 120000 index a48591198..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py deleted file mode 120000 index d7349b0a3..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py deleted file mode 100755 index f5a1d750d..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ /dev/null @@ -1,846 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
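TelegramStreamIO above follows the standard logging.Handler extension pattern: subclass the handler, implement emit(), set a level and formatter, and attach it to the root logger (setup_logger does exactly this at WARN level). A minimal offline sketch of the same wiring, with printing standing in for the Telegram API call:

import logging

class PrintHandler(logging.Handler):
    # Stand-in for TelegramStreamIO: emit() prints instead of calling an HTTP API.
    def emit(self, record: logging.LogRecord) -> None:
        print("WOULD SEND:", self.format(record))

handler = PrintHandler()
handler.setLevel(logging.WARNING)  # TelegramStreamIO is attached at WARN level
handler.setFormatter(logging.Formatter("%(asctime)s\n%(message)s"))
logging.getLogger("").addHandler(handler)   # same root-logger attachment as setup_logger
logging.warning("validation loss went up")  # forwarded to the handler
logging.info("this INFO message is not")    # below the handler's level, so dropped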
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --lang data/lang_char \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method beam_search \ - --lang data/lang_char \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method modified_beam_search \ - --lang data/lang_char \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --lang data/lang_char \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --lang data/lang_char \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --lang data/lang_char \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --lang data/lang_char \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import CSJAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from tokenizer import Tokenizer -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. 
- Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir. It should contain at least a word table.", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. 
- Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=30, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.text2word(sp.decode(hyp))) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - 
key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
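Each entry in results_dict above is a list of (cut_id, ref_words, hyp_words) tuples, and write_error_stats reduces those pairs to an error rate plus per-word statistics. For intuition only, a minimal word-error-rate computation over one ref/hyp pair; this is a simplified stand-in, not the icefall implementation:

from typing import List

def word_error_rate(ref: List[str], hyp: List[str]) -> float:
    # Levenshtein distance over words (substitutions, insertions, deletions),
    # divided by the reference length.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution or match
            )
        prev = cur
    return prev[-1] / max(len(ref), 1)

print(word_error_rate("今日 は いい 天気".split(), "今日 いい 天気 です".split()))  # 0.5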
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - return test_set_wers - - -@torch.no_grad() -def main(): - parser = get_parser() - CSJAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - if not params.res_dir: - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # and are defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, 
--avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - decoding_graph = None - word_table = None - - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
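The checkpoint-averaging branches above replace the model weights with an element-wise average over the selected checkpoints (average_checkpoints_with_averaged_model works analogously from the running averaged model between two checkpoints). A simplified sketch of plain averaging; it assumes each checkpoint stores its weights under a "model" key and is not the icefall implementation:

from typing import Dict, List

import torch

def average_model_weights(filenames: List[str]) -> Dict[str, torch.Tensor]:
    # Element-wise mean of the weights stored in each checkpoint file.
    avg: Dict[str, torch.Tensor] = {}
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]  # assumed layout
        for k, v in state.items():
            avg[k] = avg.get(k, 0) + v.float()
    return {k: v / len(filenames) for k, v in avg.items()}

# Usage sketch: model.load_state_dict(average_model_weights(filenames), strict=False)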
- args.return_cuts = True - csj_corpus = CSJAsrDataModule(args) - - for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: - results_dict = decode_dataset( - dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()), - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - tot_err = save_results( - params=params, - test_set_name=subdir, - results_dict=results_dict, - ) - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 120000 index ca8fed319..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py deleted file mode 120000 index 1ce277aa6..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py deleted file mode 100755 index 6d256308c..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ /dev/null @@ -1,1294 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
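The .cer files written at the end of main() above are plain tab-separated listings of one "setting, error rate" pair per line. A small sketch of producing such a file; the values and the file name are made up for illustration:

from pathlib import Path

tot_err = [("greedy_search", 9.8), ("beam_20.0_max_contexts_8_max_states_64", 9.1)]
out = Path("eval1-32_4_15_30.cer")  # hypothetical result file name
out.write_text("\n".join(f"{k}\t{v}" for k, v in tot_err) + "\n")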
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CSJAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - -try: - from TelegramStreamIO import TelegramStreamIO - - HAS_TELEGRAM = True -except ImportError: - HAS_TELEGRAM = False - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the 
encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") - - parser.add_argument( - "--telegram-cred", - type=Path, - default=None, - help="Telegram credentials to report progress in telegram", - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] 
= None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. 
- train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
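# For reference (standard PyTorch API, not anything icefall-specific), the
# built-in growth policy is chosen when the scaler is constructed, e.g.:
#
#     scaler = torch.cuda.amp.GradScaler(
#         init_scale=1.0,
#         growth_factor=2.0,
#         backoff_factor=0.5,
#         growth_interval=2000,
#     )
#
# but growth_interval fires on a fixed schedule no matter how small the scale
# currently is, which is why the code below instead doubles the scale by hand
# whenever it falls below 1.0 (or below 8.0, every 400 batches).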
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: - logging.warning( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - else: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - if ( - HAS_TELEGRAM - and batch_idx % (params.valid_interval * 3) == 0 - and not rank - ): - log_mode = logging.warning - else: - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - if HAS_TELEGRAM and params.telegram_cred: - TelegramStreamIO.setup_logger(params) - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - csj_corpus = CSJAsrDataModule(args) - train_cuts = csj_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = csj_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = csj_corpus.valid_cuts() - valid_dl = csj_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - CSJAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py deleted file mode 120000 index cb673b3eb..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py deleted file mode 100755 index 06a0fa96b..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python3 - -""" -Please see -https://k2-fsa.github.io/icefall/model-export/export-ncnn.html -for more details about how to use this file. - -We use -https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 - -to demonstrate the usage of this file. - -1. Download the pre-trained model - -cd egs/csj/ASR - -repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp_fluent/pretrained.pt" - -cd exp_fluent -ln -s pretrained.pt epoch-99.pt -popd - -2. 
Export to ncnn - -./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --lang $repo/data/lang_char \ - --exp-dir $repo/exp_fluent/ \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - \ - --decode-chunk-len 32 \ - --num-left-chunks 4 \ - --num-encoder-layers "2,4,3,2,4" \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --nhead "8,8,8,8,8" \ - --encoder-dims "384,384,384,384,384" \ - --attention-dims "192,192,192,192,192" \ - --encoder-unmasked-dims "256,256,256,256,256" \ - --zipformer-downsampling-factors "1,2,4,8,2" \ - --cnn-module-kernels "31,31,31,31,31" \ - --decoder-dim 512 \ - --joiner-dim 512 - -cd $repo/exp_fluent - -pnnx encoder_jit_trace-pnnx.pt -pnnx decoder_jit_trace-pnnx.pt -pnnx joiner_jit_trace-pnnx.pt - -You can find converted models at -https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14 - -Please also have a look at -https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14/blob/main/export-for-ncnn-fluent.sh - -See ./streaming-ncnn-decode.py -and -https://github.com/k2-fsa/sherpa-ncnn -for usage. -""" - -import argparse -import logging -from pathlib import Path - -import torch -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model -from scaling_converter import convert_scaled_to_non_scaled -from tokenizer import Tokenizer - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. 
- """ - encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - - decode_chunk_len = encoder_model.decode_chunk_size * 2 - pad_length = 7 - T = decode_chunk_len + pad_length # 32 + 7 = 39 - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - - x = torch.zeros(1, T, 80, dtype=torch.float32) - states = encoder_model.get_init_state() - - traced_model = torch.jit.trace(encoder_model, (x, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") - - logging.info(f"device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - assert params.blank_id == 0, params.blank_id - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - 
logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) - - encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - logging.info("Using torch.jit.trace()") - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 100644 index 2d45ecca3..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/csj/ASR - ./pruned_transducer_stateless7_streaming/decode.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 - # You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent -""" - -import argparse -import logging -from pathlib import Path - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from tokenizer import Tokenizer -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - 
filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py deleted file mode 100644 index ab7c8748a..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 - -""" -Usage: -# use -O to skip assertions and avoid some of the TracerWarnings -python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 30 \ - --avg 10 \ - --use-averaged-model=True \ - --decode-chunk-len 32 -""" - -import argparse -import logging -from pathlib import Path - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from tokenizer import Tokenizer -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, - params: AttributeDict, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - decode_chunk_len = params.decode_chunk_len # before subsampling - pad_length = 7 - s = f"decode_chunk_len: {decode_chunk_len}" - logging.info(s) - assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( - encoder_model.decode_chunk_size, - decode_chunk_len, - ) - - T = decode_chunk_len + pad_length - - x = torch.zeros(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - states = encoder_model.get_init_state(device=x.device) - - encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. 
- - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # <blk> is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" -
export_encoder_model_jit_trace(model.encoder, encoder_filename, params) - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py deleted file mode 100644 index 58ee99e6a..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ /dev/null @@ -1,287 +0,0 @@ -#!/usr/bin/env python3 -# flake8: noqa -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models exported by `torch.jit.trace()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 30 \ - --avg 10 \ - --use-averaged-model=True \ - --decode-chunk-len 32 - -Usage of this script: - -./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ - --lang data/lang_char \ - --decode-chunk-len 32 \ - /path/to/foo.wav \ -""" - -import argparse -import logging -from typing import List, Optional - -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. 
", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -): - assert encoder_out.ndim == 2 - context_size = 2 - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) - # decoder_input.shape (1,, 1 context_size) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - else: - assert decoder_out.ndim == 2 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for i in range(T): - cur_encoder_out = encoder_out[i : i + 1] - joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - - decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - - return hyp, decoder_out - - -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. 
- """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - return OnlineFbank(opts) - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - sp = Tokenizer.load(args.lang, args.lang_type) - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - - logging.info(f"Reading sound files: {args.sound_file}") - wave_samples = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=args.sample_rate, - )[0] - logging.info(wave_samples.shape) - - logging.info("Decoding started") - chunk_length = args.decode_chunk_len - assert encoder.decode_chunk_size == chunk_length // 2, ( - encoder.decode_chunk_size, - chunk_length, - ) - - # we subsample features with ((x_len - 7) // 2 + 1) // 2 - pad_length = 7 - T = chunk_length + pad_length - - logging.info(f"chunk_length: {chunk_length}") - - states = encoder.get_init_state(device) - - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.2 second - num_processed_frames = 0 - - hyp = None - decoder_out = None - - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32) - encoder_out, out_lens, states = encoder( - x=frames, - x_lens=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp - ) - - context_size = 2 - logging.info(args.sound_file) - logging.info(sp.decode(hyp[context_size:])) - - logging.info("Decoding Done") - - -torch.set_num_threads(4) -torch.set_num_interop_threads(1) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py deleted file mode 120000 index 482ebcfef..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py deleted file mode 120000 index 16c2bf28d..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py deleted file mode 120000 index 522bbaff9..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 100644 index 66fbae378..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. 
- -Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_streaming/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = Tokenizer.load(params.lang, params.lang_type) - - # <blk> is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - -
hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py deleted file mode 120000 index a7ef73bcb..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py deleted file mode 120000 index 566c317ff..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py deleted file mode 120000 index 92c3904af..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py deleted file mode 120000 index 2adf271c1..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py deleted file mode 100755 index 6a249dd3f..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ /dev/null @@ -1,603 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --lang data/lang_char \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -import torch.nn as nn -from asr_datamodule import CSJAsrDataModule -from decode import save_results -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, setup_logger, str2bool - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
- tail_length = 23 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 50 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... 
- # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -@torch.no_grad() -def main(): - parser = get_parser() - CSJAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if not params.res_dir: - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # <blk> and <unk> is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if
len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - args.return_cuts = True - csj_corpus = CSJAsrDataModule(args) - - for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: - results_dict = decode_dataset( - cuts=getattr(csj_corpus, f"{subdir}_cuts")(), - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - tot_err = save_results( - params=params, test_set_name=subdir, results_dict=results_dict - ) - - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py deleted file mode 100755 index 0a82ccfa4..000000000 --- 
a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/csj/ASR - python ./pruned_transducer_stateless7_streaming/test_model.py -""" - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - # Test jit script - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - print("Using torch.jit.script") - model = torch.jit.script(model) - - -def test_model_jit_trace(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - convert_scaled_to_non_scaled(model, inplace=True) - - # Test encoder - def _test_encoder(): - encoder = model.encoder - assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - encoder.decode_chunk_size, - params.decode_chunk_len, - ) - T = params.decode_chunk_len + 7 - - x = torch.zeros(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - states = encoder.get_init_state(device=x.device) - encoder.__class__.forward = encoder.__class__.streaming_forward - traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) - - states1 = encoder.get_init_state(device=x.device) - states2 = traced_encoder.get_init_state(device=x.device) - for i in range(5): - x = torch.randn(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) - y2, _, states2 = traced_encoder(x, x_lens, states2) - assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) - - # Test decoder - def _test_decoder(): - decoder = model.decoder - y = torch.zeros(10, decoder.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_decoder = torch.jit.trace(decoder, (y, need_pad)) - d1 = decoder(y, need_pad) - d2 = traced_decoder(y, need_pad) - assert torch.equal(d1, d2), (d1 - d2).abs().mean() - - # Test joiner - def _test_joiner(): - joiner = model.joiner - encoder_out_dim = joiner.encoder_proj.weight.shape[1] - decoder_out_dim = joiner.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) - j1 = joiner(encoder_out, decoder_out) - j2 = traced_joiner(encoder_out, decoder_out) - assert torch.equal(j1, j2), (j1 - j2).abs().mean() - - _test_encoder() - _test_decoder() - _test_joiner() - - -def main(): - test_model() - test_model_jit_trace() - - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py deleted file mode 120000 index 958c99e85..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py deleted file mode 100755 index ef7ea9013..000000000 --- 
a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ /dev/null @@ -1,1292 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import CSJAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - -try: - from TelegramStreamIO import TelegramStreamIO - - HAS_TELEGRAM = True -except ImportError: - HAS_TELEGRAM = False - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - 
"--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") - - parser.add_argument( - "--telegram-cred", - type=Path, - default=None, - help="Telegram credentials to report progress in telegram", - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. 
- If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. 
`model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. 
- "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. 
- model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
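# --- Illustrative aside (not part of the training loop) --------------------------------
# The warm-up interpolation that compute_loss() above applies to the two RNN-T loss
# terms: the simple-loss weight decays from 1.0 towards --simple-loss-scale while the
# pruned-loss weight ramps from 0.1 to 1.0 over the first warm_step batches.  The
# defaults (warm_step=2000, s=0.5) are taken from get_params() and the parser above.
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

print(loss_scales(0))     # (1.0, 0.1) at the start of training
print(loss_scales(1000))  # approximately (0.75, 0.55) halfway through the warm-up
print(loss_scales(2000))  # (0.5, 1.0) fully warmed up
# ----------------------------------------------------------------------------------------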
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: - logging.warning( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - else: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - if ( - HAS_TELEGRAM - and batch_idx % (params.valid_interval * 3) == 0 - and not rank - ): - log_mode = logging.warning - else: - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - if HAS_TELEGRAM and params.telegram_cred: - TelegramStreamIO.setup_logger(params) - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - csj_corpus = CSJAsrDataModule(args) - train_cuts = csj_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = csj_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = csj_corpus.valid_cuts() - valid_dl = csj_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - CSJAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 120000 index ec183baa7..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py deleted file mode 120000 index d301e1f9b..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/csj/ASR/shared b/egs/csj/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/csj/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/README.md b/egs/fluent_speech_commands/SLU/README.md deleted file mode 100755 index a203a9bfb..000000000 --- a/egs/fluent_speech_commands/SLU/README.md +++ /dev/null @@ -1,9 +0,0 @@ -## Fluent Speech Commands recipe - -This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances. 
- -Dataset Paper link: - -cd icefall/egs/fluent_speech_commands/ -Training: python transducer/train.py -Decoding: python transducer/decode.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py deleted file mode 100755 index a7df8f966..000000000 --- a/egs/fluent_speech_commands/SLU/local/compile_hlg.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python3 - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - logging.info("Loading G.fst.txt") - with open(lang_dir / "G.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. 
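# --- Illustrative aside (toy tensors only) ----------------------------------------------
# Why the disambiguation-symbol removal above copies LG.labels out, masks it and assigns
# it back (see the referenced k2 pull request): the same masking pattern on a plain
# tensor.  The threshold value below is made up for the illustration.
import torch

labels = torch.tensor([3, 7, 12, 5, 13])
first_token_disambig_id = 12                   # assumed toy value of token_table["#0"]
labels[labels >= first_token_disambig_id] = 0  # disambig symbols become epsilon (0)
print(labels)                                  # tensor([3, 7, 0, 5, 0])
# ------------------------------------------------------------------------------------------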
- HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py deleted file mode 100755 index a51b7b47b..000000000 --- a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file computes fbank features of the Fluent Speech Commands dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or it wastes a -# lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_slu(manifest_dir, fbanks_dir): - src_dir = Path(manifest_dir) - output_dir = Path(fbanks_dir) - - # This dataset is rather small, so we use only one job - num_jobs = min(1, os.cpu_count()) - num_mel_bins = 23 - - dataset_parts = ( - "train", - "valid", - "test", - ) - prefix = "slu" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" - if cuts_file.is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 1, # use one job - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(cuts_file) - - -parser = argparse.ArgumentParser() -parser.add_argument("manifest_dir") -parser.add_argument("fbanks_dir") - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - args = parser.parse_args() - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_slu(args.manifest_dir, args.fbanks_dir) diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py deleted file mode 100755 index 6263e062f..000000000 --- a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py +++ /dev/null @@ -1,59 +0,0 @@ -import argparse - -import pandas -from tqdm import tqdm - - -def generate_lexicon(corpus_dir, lm_dir): - data = pandas.read_csv( - str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0 - ) - vocab_transcript = set() - vocab_frames = set() - transcripts = data["transcription"].tolist() - frames = list( - i - for i in zip( - data["action"].tolist(), data["object"].tolist(), data["location"].tolist() - ) - ) - - for transcript in tqdm(transcripts): - for word in transcript.split(): - vocab_transcript.add(word) - - for frame in tqdm(frames): - for word in frame: - vocab_frames.add("_".join(word.split())) - - with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file: - lexicon_transcript_file.write(" 1" + "\n") - lexicon_transcript_file.write(" 2" + "\n") - lexicon_transcript_file.write(" 0" + "\n") - id = 3 - for vocab in vocab_transcript: - lexicon_transcript_file.write(vocab + " " + str(id) + "\n") - id += 1 - - with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file: - lexicon_frames_file.write(" 1" + "\n") - lexicon_frames_file.write(" 2" + "\n") - lexicon_frames_file.write(" 0" + "\n") - id = 3 - for vocab in vocab_frames: - lexicon_frames_file.write(vocab + " " + str(id) + "\n") - id += 1 - - -parser = argparse.ArgumentParser() -parser.add_argument("corpus_dir") -parser.add_argument("lm_dir") - - -def main(): - args = parser.parse_args() - - generate_lexicon(args.corpus_dir, args.lm_dir) - - -main() diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py deleted file mode 100755 index 2a71dcf81..000000000 --- a/egs/fluent_speech_commands/SLU/local/prepare_lang.py +++ /dev/null @@ -1,371 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. 
Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. 
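# --- Illustrative aside (toy lexicon) ----------------------------------------------------
# The two conditions computed above, shown on a toy lexicon: "a b" needs a disambig
# symbol because it is a prefix of "a b c", and "a b c" needs #1/#2 because two words
# share that pronunciation.
from collections import defaultdict

toy_lexicon = [("AB", ["a", "b"]), ("ABC", ["a", "b", "c"]), ("ABC2", ["a", "b", "c"])]

toy_count = defaultdict(int)
for _, toks in toy_lexicon:
    toy_count[" ".join(toks)] += 1

toy_issubseq = defaultdict(int)
for _, toks in toy_lexicon:
    toks = toks[:-1]
    while toks:
        toy_issubseq[" ".join(toks)] = 1
        toks = toks[:-1]

print(toy_count["a b c"])   # 2 -> "ABC" gets "a b c #1", "ABC2" gets "a b c #2"
print(toy_issubseq["a b"])  # 1 -> "AB" gets "a b #1"
# --------------------------------------------------------------------------------------------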
- ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "!SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. 
- next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - # assert token2id[""] == 0 - # assert word2id[""] == 0 - - eps = 0 - sil_token = word2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [word2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = word2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -parser = argparse.ArgumentParser() -parser.add_argument("lm_dir") - - -def main(): - args = parser.parse_args() - - out_dir = Path(args.lm_dir) - lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"] - names = ["frames", "transcript"] - sil_token = "!SIL" - sil_prob = 0.5 - - for name, lexicon_filename in zip(names, lexicon_filenames): - lexicon = read_lexicon(lexicon_filename) - tokens = get_words(lexicon) - words = get_words(lexicon) - new_lexicon = [] - for lexicon_item in lexicon: - new_lexicon.append((lexicon_item[0], [lexicon_item[0]])) - lexicon = new_lexicon - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - tokens = [""] + tokens - words = ["eps"] + words + ["#0", "!SIL"] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id) - write_mapping(out_dir / ("words_" + name + ".txt"), word2id) - write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=word2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=word2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt")) - torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt")) - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -main() diff --git 
a/egs/fluent_speech_commands/SLU/prepare.sh b/egs/fluent_speech_commands/SLU/prepare.sh deleted file mode 100755 index 3ff339d91..000000000 --- a/egs/fluent_speech_commands/SLU/prepare.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=1 -stop_stage=5 - -data_dir=path/to/fluent/speech/commands -target_root_dir=data/ - -lang_dir=${target_root_dir}/lang_phone -lm_dir=${target_root_dir}/lm -manifest_dir=${target_root_dir}/manifests -fbanks_dir=${target_root_dir}/fbanks - -. shared/parse_options.sh || exit 1 - -mkdir -p $lang_dir -mkdir -p $lm_dir - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "data_dir: $data_dir" - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare slu manifest" - mkdir -p $manifest_dir - lhotse prepare slu $data_dir $manifest_dir -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute fbank for SLU" - mkdir -p $fbanks_dir - python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare lang" - # NOTE: " SIL" is added for implementation convenience - # as the graph compiler code requires that there is a OOV word - # in the lexicon. - python ./local/generate_lexicon.py $data_dir $lm_dir -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Train LM" - # We use a unigram G - ./shared/make_kn_lm.py \ - -ngram-order 1 \ - -text $lm_dir/words_transcript.txt \ - -lm $lm_dir/G_transcript.arpa - - ./shared/make_kn_lm.py \ - -ngram-order 1 \ - -text $lm_dir/words_frames.txt \ - -lm $lm_dir/G_frames.arpa - - python ./local/prepare_lang.py $lm_dir - - if [ ! -f $lm_dir/G_transcript.fst.txt ]; then - python -m kaldilm \ - --read-symbol-table="$lm_dir/words_transcript.txt" \ - $lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt - fi - - if [ ! -f $lm_dir/G_frames.fst.txt ]; then - python -m kaldilm \ - --read-symbol-table="$lm_dir/words_frames.txt" \ - $lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt - fi - - mkdir -p $lm_dir/frames - mkdir -p $lm_dir/transcript - - chmod -R +777 . 
- - for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt; - do - j=${i//"_frames"/} - mv "$lm_dir/$i" $lm_dir/frames/$j - done - - for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt; - do - j=${i//"_transcript"/} - mv "$lm_dir/$i" $lm_dir/transcript/$j - done -fi - - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compile HLG" - ./local/compile_hlg.py --lang-dir $lm_dir/frames - ./local/compile_hlg.py --lang-dir $lm_dir/transcript - -fi diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/fluent_speech_commands/SLU/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/__init__.py b/egs/fluent_speech_commands/SLU/transducer/__init__.py deleted file mode 100755 index e69de29bb..000000000 diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py deleted file mode 100755 index a16aa0123..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/beam_search.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import torch -from transducer.model import Transducer - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, id2word: dict -) -> List[str]: - """ - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - Returns: - Return the decoded result. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - hyp = [] - max_u = 1000 # terminate after this number of steps - u = 0 - - while t < T and u < max_u: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (N, 1, 1) - # TODO: Use logits.argmax() - y = log_prob.argmax() - if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out, (h, c) = model.decoder(y, (h, c)) - u += 1 - else: - t += 1 - # id2word = {1: "YES", 2: "NO"} - - hyp = [id2word[i] for i in hyp] - - return hyp diff --git a/egs/fluent_speech_commands/SLU/transducer/conformer.py b/egs/fluent_speech_commands/SLU/transducer/conformer.py deleted file mode 120000 index 8be0dc864..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py deleted file mode 100755 index ba2b9aaea..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/decode.py +++ /dev/null @@ -1,346 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from transducer.beam_search import greedy_search -from transducer.conformer import Conformer -from transducer.decoder import Decoder -from transducer.joiner import Joiner -from transducer.model import Transducer -from transducer.slu_datamodule import SluDataModule - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_id2word(params): - id2word = {} - - # 0 is blank - id = 1 - try: - with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: - for line in lexicon_file: - if len(line.strip()) > 0: - id2word[id] = line.split()[0] - id += 1 - except: - pass - - return id2word - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=6, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--exp-dir", - type=str, - default="transducer/exp", - help="Directory from which to load the checkpoints", - ) - parser.add_argument("--lang-dir", type=str, default="data/lm/frames") - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "lang_dir": Path("data/lm/frames"), - # encoder/decoder params - "vocab_size": 3, # blank, yes, no - "blank_id": 0, - "embedding_dim": 32, - "hidden_dim": 16, - "num_decoder_layers": 4, - } - ) - - vocab_size = 1 - with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file: - for line in lexicon_file: - if ( - len(line.strip()) > 0 - ): # and '' not in line and '' not in line and '' not in line: - vocab_size += 1 - params.vocab_size = vocab_size - - return params - - -def decode_one_batch( - params: AttributeDict, model: nn.Module, batch: dict, id2word: dict -) -> List[List[int]]: - """Decode one batch and return the result in a list-of-list. - Each sub list contains the word IDs for an utterance in the batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding. - - params.method is "nbest", it uses nbest decoding. - - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) - Returns: - Return the decoding result. `len(ans)` == batch size. - """ - device = model.device - feature = batch["inputs"] - feature = feature.to(device) - # at entry, feature is (N, T, C) - feature_lens = batch["supervisions"]["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word) - hyps.append(hyp) - - # hyps = [[word_table[i] for i in ids] for ids in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> List[Tuple[List[int], List[int]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a tuple contains two elements (ref_text, hyp_text): - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
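# --- Illustrative aside (random tensors) ---------------------------------------------------
# How decode_one_batch() above feeds greedy_search() one utterance at a time: each slice
# keeps the batch dimension (size 1) and drops the padded frames beyond
# encoder_out_lens[i].
import torch

encoder_out = torch.randn(2, 10, 16)       # (N, T, C), toy sizes
encoder_out_lens = torch.tensor([10, 7])   # valid frames per utterance

for i in range(encoder_out.size(0)):
    encoder_out_i = encoder_out[i : i + 1, : encoder_out_lens[i]]
    print(encoder_out_i.shape)             # torch.Size([1, 10, 16]) then torch.Size([1, 7, 16])
# ----------------------------------------------------------------------------------------------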
- - id2word = get_id2word(params) - - results = [] - for batch_idx, batch in enumerate(dl): - texts = [ - " ".join(a.supervisions[0].custom["frames"]) - for a in batch["supervisions"]["cut"] - ] - texts = [ - " " + a.replace("change language", "change_language") + " " - for a in texts - ] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, model=model, batch=batch, id2word=id2word - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - exp_dir: Path, - test_set_name: str, - results: List[Tuple[List[int], List[int]]], -) -> None: - """Save results to `exp_dir`. - Args: - exp_dir: - The output directory. This function create the following files inside - this directory: - - - recogs-{test_set_name}.text - - It contains the reference and hypothesis results, like below:: - - ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - - - errs-{test_set_name}.txt - - It contains the detailed WER. - test_set_name: - The name of the test set, which will be part of the result filename. - results: - A list of tuples, each of which contains (ref_words, hyp_words). - Returns: - Return None. - """ - recog_path = exp_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
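As a note on the reference texts built in `decode_dataset` above: the multi-word slot value "change language" is joined into a single token so that each slot is scored as one unit by the error-rate computation. A minimal sketch of that normalization, using a made-up frame sequence:

```python
# Hypothetical frame labels attached to one cut's supervision, as read in
# decode_dataset(); "change language" spans two words in the raw labels.
frames = ["change language", "none", "none"]

text = " ".join(frames)
# Join the multi-word value so write_error_stats() counts it as one token,
# mirroring the replacement done in decode_dataset().
text = text.replace("change language", "change_language")
ref_words = text.split()
print(ref_words)  # ['change_language', 'none', 'none']
```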
- errs_filename = exp_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - write_error_stats(f, f"{test_set_name}", results) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - -def get_transducer_model(params: AttributeDict): - # encoder = Tdnn( - # num_features=params.feature_dim, - # output_dim=params.hidden_dim, - # ) - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.hidden_dim, - ) - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.hidden_dim, - embedding_dropout=0.4, - rnn_dropout=0.4, - ) - joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) - return transducer - - -@torch.no_grad() -def main(): - parser = get_parser() - SluDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to(device) - model.eval() - model.device = device - - # we need cut ids to display recognition results. 
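Regarding the checkpoint-averaging branch in `main()` above: with `--epoch 6 --avg 3`, the checkpoints from epochs 4 through 6 are averaged. A minimal sketch of how that filename list is assembled (the paths are illustrative; in `decode.py` they come from `--exp-dir`, `--epoch`, and `--avg`):

```python
# Illustrative values standing in for the parsed command-line arguments.
epoch, avg = 6, 3
exp_dir = "transducer/exp"

start = epoch - avg + 1
filenames = []
for i in range(start, epoch + 1):
    if start >= 0:  # same guard as in main(): skip averaging before epoch 0
        filenames.append(f"{exp_dir}/epoch-{i}.pt")
print(filenames)
# ['transducer/exp/epoch-4.pt', 'transducer/exp/epoch-5.pt', 'transducer/exp/epoch-6.pt']
```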
- args.return_cuts = True - slu = SluDataModule(args) - test_dl = slu.test_dataloaders() - results = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - test_set_name = str(args.feature_dir).split("/")[-2] - save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/fluent_speech_commands/SLU/transducer/decoder.py b/egs/fluent_speech_commands/SLU/transducer/decoder.py deleted file mode 120000 index e99310f91..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../yesno/ASR/transducer/decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/joiner.py b/egs/fluent_speech_commands/SLU/transducer/joiner.py deleted file mode 120000 index 75fa64868..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer/joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/model.py b/egs/fluent_speech_commands/SLU/transducer/model.py deleted file mode 120000 index 10f6ddad1..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer/model.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py deleted file mode 100755 index fa715abdd..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import List - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class SluDataModule(DataModule): - """ - DataModule for k2 ASR experiments. - It assumes there is always one train dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). 
- - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbanks"), - help="Path to directory with train/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=30.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=False, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=10, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to create train dataset") - transforms = [] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
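A quick intuition for the `--max-duration` option above: the lhotse samplers pack cuts into a batch until the total audio duration would exceed the budget, so the number of utterances per batch varies with their lengths. The sketch below is a deliberately simplified greedy packer, not lhotse's actual `DynamicBucketingSampler`/`SimpleCutSampler` logic (which adds bucketing, shuffling, etc.):

```python
# Simplified illustration of duration-budget batching only.
max_duration = 30.0  # seconds, as passed via --max-duration
durations = [9.0, 8.5, 7.0, 6.5, 10.0, 3.0]  # made-up cut durations

batches, current, total = [], [], 0.0
for d in durations:
    if current and total + d > max_duration:
        batches.append(current)   # budget exceeded: start a new batch
        current, total = [], 0.0
    current.append(d)
    total += d
if current:
    batches.append(current)

print(batches)  # [[9.0, 8.5, 7.0], [6.5, 10.0, 3.0]]
```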
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - FbankConfig(sampling_rate=8000, num_mel_bins=23) - ), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - - return train_dl - - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get valid cuts") - cuts_valid = self.valid_cuts() - - logging.debug("About to create valid dataset") - valid = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create valid dataloader") - valid_dl = DataLoader( - valid, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - return valid_dl - - def test_dataloaders(self) -> DataLoader: - logging.info("About to get test cuts") - cuts_test = self.test_cuts() - - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts_test, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.feature_dir / "slu_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> List[CutSet]: - logging.info("About to get valid cuts") - cuts_valid = load_manifest_lazy( - self.args.feature_dir / "slu_cuts_valid.jsonl.gz" - ) - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - 
logging.info("About to get test cuts") - cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz") - return cuts_test diff --git a/egs/fluent_speech_commands/SLU/transducer/subsampling.py b/egs/fluent_speech_commands/SLU/transducer/subsampling.py deleted file mode 120000 index fd7ca8b30..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py deleted file mode 120000 index 3060dd70c..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer/test_conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py deleted file mode 120000 index d1bc718ce..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../yesno/ASR/transducer/test_decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py deleted file mode 120000 index 248222a8a..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer/test_joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py deleted file mode 120000 index df104bad7..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer/test_transducer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py deleted file mode 100755 index a59c0b754..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/train.py +++ /dev/null @@ -1,625 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import List, Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from lhotse.utils import fix_random_seed -from slu_datamodule import SluDataModule -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from transducer.conformer import Conformer - -# from torch.utils.tensorboard import SummaryWriter -from transducer.decoder import Decoder -from transducer.joiner import Joiner -from transducer.model import Transducer - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_word2id(params): - word2id = {} - - # 0 is blank - id = 1 - with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: - for line in lexicon_file: - if len(line.strip()) > 0: - word2id[line.split()[0]] = id - id += 1 - - return word2id - - -def get_labels(texts: List[str], word2id) -> k2.RaggedTensor: - """ - Args: - texts: - A list of transcripts. - Returns: - Return a ragged tensor containing the corresponding word ID. - """ - # blank is 0 - word_ids = [] - for t in texts: - words = t.split() - ids = [word2id[w] for w in words] - word_ids.append(ids) - - return k2.RaggedTensor(word_ids) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=7, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer/exp", - help="Directory to save results", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument("--lang-dir", type=str, default="data/lm/frames") - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. 
- - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - """ - params = AttributeDict( - { - "lr": 1e-4, - "feature_dim": 23, - "weight_decay": 1e-6, - "start_epoch": 0, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 20, - "valid_interval": 3000, - "exp_dir": Path("transducer/exp"), - "lang_dir": Path("data/lm/frames"), - # encoder/decoder params - "vocab_size": 3, # blank, yes, no - "blank_id": 0, - "embedding_dim": 32, - "hidden_dim": 16, - "num_decoder_layers": 4, - } - ) - - vocab_size = 1 - with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: - for line in lexicon_file: - if ( - len(line.strip()) > 0 - ): # and '' not in line and '' not in line and '' not in line: - vocab_size += 1 - params.vocab_size = vocab_size - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Tdnn in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - feature_lens = batch["supervisions"]["num_frames"].to(device) - - texts = [ - " ".join(a.supervisions[0].custom["frames"]) - for a in batch["supervisions"]["cut"] - ] - texts = [ - " " + a.replace("change language", "change_language") + " " - for a in texts - ] - labels = get_labels(texts, word2ids).to(device) - - with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=labels) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = feature.size(0) - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - valid_dl: torch.utils.data.DataLoader, - word2ids, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=False, - word2ids=word2ids, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - word2ids, - tb_writer: None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. 
If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, model=model, batch=batch, is_training=True, word2ids=word2ids - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - word2ids=word2ids, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def get_transducer_model(params: AttributeDict): - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.hidden_dim, - ) - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.hidden_dim, - embedding_dropout=0.4, - rnn_dropout=0.4, - ) - joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) - - return transducer - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - - params.update(vars(args)) - params["env_info"] = get_env_info() - - word2ids = get_word2id(params) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - # if args.tensorboard and rank == 0: - # tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - # else: - # tb_writer = None - tb_writer = None - - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - else: - device = torch.device("cpu") - logging.info(f"device: {device}") - - model = get_transducer_model(params) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - model.device = device - - optimizer = optim.Adam( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - slu = SluDataModule(args) - train_dl = slu.train_dataloaders() - - # There are only 60 waves: 30 files are used for training - # and the remaining 30 files are used for testing. - # We use test data as validation. - valid_dl = slu.test_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - word2ids=word2ids, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=None, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - SluDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/fluent_speech_commands/SLU/transducer/transformer.py b/egs/fluent_speech_commands/SLU/transducer/transformer.py deleted file mode 120000 index 214afed39..000000000 --- a/egs/fluent_speech_commands/SLU/transducer/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore deleted file mode 100644 index 8dec2d86d..000000000 --- a/egs/gigaspeech/ASR/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -log-* -.DS_Store \ No newline at end of file diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md deleted file mode 100644 index f0d60898c..000000000 --- a/egs/gigaspeech/ASR/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# GigaSpeech -GigaSpeech, an evolving, multi-domain English -speech recognition corpus with 10,000 hours of high quality labeled -audio, collected from audiobooks, podcasts -and YouTube, covering both read and spontaneous speaking styles, -and a variety of topics, such as arts, science, sports, etc. 
More details can be found: https://github.com/SpeechColab/GigaSpeech - -## Download - -Apply for the download credentials and download the dataset by following https://github.com/SpeechColab/GigaSpeech#download. Then create a symlink -```bash -ln -sfv /path/to/GigaSpeech download/GigaSpeech -``` - -## Performance Record -| | Dev | Test | -|--------------------------------|-------|-------| -| `zipformer` | 10.25 | 10.38 | -| `conformer_ctc` | 10.47 | 10.58 | -| `pruned_transducer_stateless2` | 10.40 | 10.51 | - -See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details. diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md deleted file mode 100644 index 841ebdcfa..000000000 --- a/egs/gigaspeech/ASR/RESULTS.md +++ /dev/null @@ -1,226 +0,0 @@ -## Results -### zipformer (zipformer + pruned stateless transducer) - -See for more details. - -[zipformer](./zipformer) - -- Non-streaming -- normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M - -You can find a pretrained model, training logs, decoding logs, and decoding results at: - - -The tensorboard log for training is available at - - -You can use to deploy it. - -| decoding method | test-clean | test-other | comment | -|----------------------|------------|------------|--------------------| -| greedy_search | 10.31 | 10.50 | --epoch 30 --avg 9 | -| modified_beam_search | 10.25 | 10.38 | --epoch 30 --avg 9 | -| fast_beam_search | 10.26 | 10.48 | --epoch 30 --avg 9 | - -The training command is: -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 0 \ - --subset XL \ - --max-duration 700 \ - --use-transducer 1 \ - --use-ctc 0 \ - --lr-epochs 1 \ - --master-port 12345 -``` - -The decoding command is: -```bash -export CUDA_VISIBLE_DEVICES=0 - -# greedy search -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method greedy_search - -# modified beam search -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -# fast beam search (one best) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -``` - -### GigaSpeech BPE training results (Pruned Transducer 2) - -#### 2022-05-12 - -#### Conformer encoder + embedding decoder - -Conformer encoder + non-recurrent decoder. The encoder is a -reworked version of the conformer encoder, with many changes. The -decoder contains only an embedding layer, a Conv1d (with kernel -size 2) and a linear layer (to transform tensor dim). k2 pruned -RNN-T loss is used. 
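To make the decoder description above concrete, here is a simplified sketch of such a stateless prediction network: an embedding over the previous labels, a causal Conv1d with kernel size 2, and a linear projection. It illustrates the idea only and is not the recipe's exact `Decoder` module:

```python
import torch
import torch.nn as nn

class TinyStatelessDecoder(nn.Module):
    """Embedding -> kernel-size-2 causal Conv1d over previous labels -> Linear."""

    def __init__(self, vocab_size: int, embed_dim: int, out_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # padding=1 plus trimming the last frame keeps the convolution causal:
        # output position u depends only on labels u-1 and u.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=2, padding=1)
        self.proj = nn.Linear(embed_dim, out_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) previous label IDs
        x = self.embedding(y).permute(0, 2, 1)        # (N, embed_dim, U)
        x = self.conv(x)[:, :, :-1].permute(0, 2, 1)  # back to (N, U, embed_dim)
        return self.proj(x)

decoder = TinyStatelessDecoder(vocab_size=500, embed_dim=512, out_dim=512)
print(decoder(torch.zeros(2, 7, dtype=torch.long)).shape)  # torch.Size([2, 7, 512])
```

Because the decoder only sees a fixed, short window of previous labels instead of an unbounded recurrent state, it is cheap to run during beam search and is what makes the "stateless" transducer variants practical at scale.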
- -The best WER, as of 2022-05-12, for the gigaspeech is below - -Results are: - -| | Dev | Test | -|----------------------|-------|-------| -| greedy search | 10.51 | 10.73 | -| fast beam search | 10.50 | 10.69 | -| modified beam search | 10.40 | 10.51 | - -To reproduce the above result, use the following commands for training: - -```bash -cd egs/gigaspeech/ASR -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -./pruned_transducer_stateless2/train.py \ - --max-duration 120 \ - --num-workers 1 \ - --world-size 8 \ - --exp-dir pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --use-fp16 True -``` - -and the following commands for decoding: - -```bash -# greedy search -./pruned_transducer_stateless2/decode.py \ - --iter 3488000 \ - --avg 20 \ - --decoding-method greedy_search \ - --exp-dir pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --max-duration 600 - -# fast beam search -./pruned_transducer_stateless2/decode.py \ - --iter 3488000 \ - --avg 20 \ - --decoding-method fast_beam_search \ - --exp-dir pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --max-duration 600 - -# modified beam search -./pruned_transducer_stateless2/decode.py \ - --iter 3488000 \ - --avg 15 \ - --decoding-method modified_beam_search \ - --exp-dir pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --max-duration 600 -``` - -Pretrained model is available at - - -The tensorboard log for training is available at - - -### GigaSpeech BPE training results (Conformer-CTC) - -#### 2022-04-06 - -The best WER, as of 2022-04-06, for the gigaspeech is below - -Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: - -| | Dev | Test | -|-----|-------|-------| -| WER | 10.47 | 10.58 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.5 | 1.3 | - - -To reproduce the above result, use the following commands for training: - -```bash -cd egs/gigaspeech/ASR -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -./conformer_ctc/train.py \ - --max-duration 120 \ - --num-workers 1 \ - --world-size 8 \ - --exp-dir conformer_ctc/exp_500 \ - --lang-dir data/lang_bpe_500 -``` - -and the following command for decoding: - -```bash -./conformer_ctc/decode.py \ - --epoch 18 \ - --avg 6 \ - --method attention-decoder \ - --num-paths 1000 \ - --exp-dir conformer_ctc/exp_500 \ - --lang-dir data/lang_bpe_500 \ - --max-duration 20 \ - --num-workers 1 -``` - -Results using HLG decoding + whole lattice rescoring: - -| | Dev | Test | -|-----|-------|-------| -| WER | 10.51 | 10.62 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| lm_scale | -|----------| -| 0.2 | - -To reproduce the above result, use the training commands above, and the following command for decoding: - -```bash -./conformer_ctc/decode.py \ - --epoch 18 \ - --avg 6 \ - --method whole-lattice-rescoring \ - --num-paths 1000 \ - --exp-dir conformer_ctc/exp_500 \ - --lang-dir data/lang_bpe_500 \ - --max-duration 20 \ - --num-workers 1 -``` -Note: the `whole-lattice-rescoring` method is about twice as fast as the `attention-decoder` method, with slightly worse WER. 
- -Pretrained model is available at - - -The tensorboard log for training is available at - diff --git a/egs/gigaspeech/ASR/conformer_ctc/__init__.py b/egs/gigaspeech/ASR/conformer_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py deleted file mode 100644 index 569978424..000000000 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class GigaSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", - ) - - # GigaSpeech specific arguments - group.add_argument( - "--subset", - type=str, - default="XL", - help="Select the GigaSpeech subset (XS|S|M|L|XL)", - ) - group.add_argument( - "--small-dev", - type=str2bool, - default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", - ) - - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = 
DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train_{self.args.subset} cuts") - path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py deleted file mode 100644 index a1cfe6e75..000000000 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ /dev/null @@ -1,910 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask - - -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. 
- """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: Union[float, bool] = 0.1, - ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - use_conv_batchnorm = True - if isinstance(use_feat_batchnorm, float): - use_conv_batchnorm = False - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - use_conv_batchnorm, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. 
- - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - use_conv_batchnorm: bool = False, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm - ) - - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout( - self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - ) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src - - -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). 
-        norm: the layer normalization component (optional).
-
-    Examples::
-        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
-        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
-        >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
-        >>> out = conformer_encoder(src, pos_emb)
-    """
-
-    def __init__(
-        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
-    ) -> None:
-        super(ConformerEncoder, self).__init__(
-            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
-        )
-
-    def forward(
-        self,
-        src: Tensor,
-        pos_emb: Tensor,
-        mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        r"""Pass the input through the encoder layers in turn.
-
-        Args:
-            src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
-            mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, N is the batch size, E is the feature number
-
-        """
-        output = src
-
-        for mod in self.layers:
-            output = mod(
-                output,
-                pos_emb,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-            )
-
-        if self.norm is not None:
-            output = self.norm(output)
-
-        return output
-
-
-class RelPositionalEncoding(torch.nn.Module):
-    """Relative positional encoding module.
-
-    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
-
-    Args:
-        d_model: Embedding dimension.
-        dropout_rate: Dropout rate.
-        max_len: Maximum input length.
-
-    """
-
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
-        """Construct a RelPositionalEncoding object."""
-        super(RelPositionalEncoding, self).__init__()
-        self.d_model = d_model
-        self.xscale = math.sqrt(self.d_model)
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
-
-    def extend_pe(self, x: Tensor) -> None:
-        """Reset the positional encodings."""
-        if self.pe is not None:
-            # self.pe contains both positive and negative parts
-            # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
-        # Suppose `i` means the position of the query vector and `j` means the
-        # position of the key vector. We use positive relative positions when keys
-        # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.d_model, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.d_model)
-        )
-        pe_positive[:, 0::2] = torch.sin(position * div_term)
-        pe_positive[:, 1::2] = torch.cos(position * div_term)
-        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-
-        # Reverse the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-        pe_negative = pe_negative[1:].unsqueeze(0)
-        pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
-
-    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
-        """Add positional encoding.
-
-        Args:
-            x (torch.Tensor): Input tensor (batch, time, `*`).
-
-        Returns:
-            torch.Tensor: Encoded tensor (batch, time, `*`).
-            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
- - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. 
- - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. 
When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). 
- - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - use_batchnorm: bool = False, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - self.use_batchnorm = use_batchnorm - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - if self.use_batchnorm: - self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) - if self.use_batchnorm: - x = self.norm(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py deleted file mode 100755 index d7035a1f8..000000000 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ /dev/null @@ -1,711 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) -# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from conformer import Conformer -from gigaspeech_scoring import asr_text_post_processing - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=0, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - - (5) attention-decoder. Extract n paths from the LM rescored - lattice, the path with the highest score is the decoding result. - - (6) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=1000, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The LM dir. 
- It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. 
- """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.method == "attention-decoder": - # lattice uses a 3-gram Lm. 
We rescore it with a 4-gram LM. - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - # TODO: pass `lattice` instead of `rescored_lattice` to - # `rescore_with_attention_decoder` - - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - else: - assert False, f"Unsupported decoding method: {params.method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - if hyps_dict is not None: - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - else: - assert len(results) > 0, "It should not decode to empty in the first batch!" 
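For orientation, decode_dataset above returns a dict mapping a decoding setting (for example "lm_scale_0.7" or "no_rescore") to a list of (cut_id, ref_words, hyp_words) tuples, and save_results later ranks those settings by WER via write_error_stats. The toy sketch below only illustrates that data layout with a bare-bones word-error count; the cut ids and hypotheses are invented and this is not the icefall scoring code:

from typing import Dict, List, Tuple


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    # Plain word-level Levenshtein distance (substitutions, insertions, deletions).
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + int(r != h))
        prev = cur
    return prev[-1]


# Toy stand-in for the value returned by decode_dataset():
# key = decoding setting, value = list of (cut_id, ref_words, hyp_words).
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]] = {
    "lm_scale_0.7": [("cut-1", "i like tea".split(), "i like tea".split())],
    "lm_scale_1.3": [("cut-1", "i like tea".split(), "i like the".split())],
}

wers = {}
for key, results in results_dict.items():
    errors = sum(edit_distance(ref, hyp) for _, ref, hyp in results)
    words = sum(len(ref) for _, ref, _ in results)
    wers[key] = errors / max(words, 1)

print(wers, "best:", min(wers, key=wers.get))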
- this_batch = [] - hyp_words = [] - for cut_id, ref_text in zip(cut_ids, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - for lm_scale in results.keys(): - results[lm_scale].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = post_processing(results) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may 
take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
-    args.return_cuts = True
-    gigaspeech = GigaSpeechAsrDataModule(args)
-
-    dev_cuts = gigaspeech.dev_cuts()
-    test_cuts = gigaspeech.test_cuts()
-
-    dev_dl = gigaspeech.test_dataloaders(dev_cuts)
-    test_dl = gigaspeech.test_dataloaders(test_cuts)
-
-    test_sets = ["dev", "test"]
-    test_dls = [dev_dl, test_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            bpe_model=bpe_model,
-            word_table=lexicon.word_table,
-            G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
-
-    logging.info("Done!")
-
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
deleted file mode 100755
index ef53b77f8..000000000
--- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
+++ /dev/null
@@ -1,115 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Jiayu Du
-# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import os
-
-conversational_filler = [
-    "UH",
-    "UHH",
-    "UM",
-    "EH",
-    "MM",
-    "HM",
-    "AH",
-    "HUH",
-    "HA",
-    "ER",
-    "OOF",
-    "HEE",
-    "ACH",
-    "EEE",
-    "EW",
-]
-unk_tags = ["<UNK>", "<unk>"]
-gigaspeech_punctuations = [
-    "<COMMA>",
-    "<PERIOD>",
-    "<QUESTIONMARK>",
-    "<EXCLAMATIONPOINT>",
-]
-gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
-non_scoring_words = (
-    conversational_filler
-    + unk_tags
-    + gigaspeech_punctuations
-    + gigaspeech_garbage_utterance_tags
-)
-
-
-def asr_text_post_processing(text: str) -> str:
-    # 1. convert to uppercase
-    text = text.upper()
-
-    # 2. remove hyphen
-    #    "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
-    text = text.replace("-", " ")
-
-    # 3.
remove non-scoring words from evaluation - remaining_words = [] - for word in text.split(): - if word in non_scoring_words: - continue - remaining_words.append(word) - - return " ".join(remaining_words) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" - ) - parser.add_argument( - "ref", - type=str, - help="sclite's standard transcription(trn) reference file", - ) - parser.add_argument( - "hyp", - type=str, - help="sclite's standard transcription(trn) hypothesis file", - ) - parser.add_argument( - "work_dir", - type=str, - help="working dir", - ) - args = parser.parse_args() - - if not os.path.isdir(args.work_dir): - os.mkdir(args.work_dir) - - REF = os.path.join(args.work_dir, "REF") - HYP = os.path.join(args.work_dir, "HYP") - RESULT = os.path.join(args.work_dir, "RESULT") - - for io in [(args.ref, REF), (args.hyp, HYP)]: - with open(io[0], "r", encoding="utf8") as fi: - with open(io[1], "w+", encoding="utf8") as fo: - for line in fi: - line = line.strip() - if line: - cols = line.split() - text = asr_text_post_processing(" ".join(cols[0:-1])) - uttid_field = cols[-1] - print(f"{text} {uttid_field}", file=fo) - - # GigaSpeech's uttid comforms to swb - os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py deleted file mode 100644 index 8e0f73d05..000000000 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, idim: int, odim: int) -> None: - """ - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - assert idim >= 7 - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). 
- - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py deleted file mode 100755 index 4883d04d8..000000000 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ /dev/null @@ -1,731 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from conformer import Conformer -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.7, - help="""The attention rate. - The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. 
It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Normalization for the input features, can be a - boolean indicating whether to do batch - normalization, or a float which means just scaling - the input features with this float value. - If given a float value, we will remove batchnorm - layer in `ConvolutionModule` as well. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 500, - "reset_interval": 2000, - "valid_interval": 30000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - "num_decoder_layers": 6, - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for Noam - "weight_decay": 1e-6, - "warm_step": 100000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = 
ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
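
[Editor's note] The removed `train_one_epoch` keeps a running total that is decayed by `1 - 1/reset_interval` before each new batch is added, and because `k2.ctc_loss` is computed with `reduction="sum"`, the value that gets logged is always the per-frame average. A minimal stand-alone sketch of that bookkeeping, with invented numbers; icefall's `MetricsTracker` implements the same idea via dict arithmetic:

```python
# Sketch of the running-loss bookkeeping in train_one_epoch (illustrative numbers).
reset_interval = 2000  # matches the default in get_params()

tot = {"loss": 0.0, "frames": 0.0}

def accumulate(tot, batch_info, reset_interval):
    """Decay the running totals, then add the new batch statistics."""
    decay = 1.0 - 1.0 / reset_interval
    return {
        k: tot.get(k, 0.0) * decay + batch_info.get(k, 0.0)
        for k in set(tot) | set(batch_info)
    }

for step, (loss_sum, n_frames) in enumerate([(5200.0, 800), (4900.0, 760), (5050.0, 790)]):
    tot = accumulate(tot, {"loss": loss_sum, "frames": float(n_frames)}, reset_interval)
    # reduction="sum" means the raw loss is a sum over utterances, so the
    # reported quantity is normalized by the accumulated frame count.
    print(f"step {step}: loss/frame = {tot['loss'] / tot['frames']:.4f}")
```
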
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - GigaSpeech = GigaSpeechAsrDataModule(args) - - train_cuts = GigaSpeech.train_cuts() - train_dl = GigaSpeech.train_dataloaders(train_cuts) - - valid_cuts = GigaSpeech.dev_cuts() - valid_dl = GigaSpeech.valid_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - 
train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/local/__init__.py b/egs/gigaspeech/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/gigaspeech/ASR/local/compile_hlg.py b/egs/gigaspeech/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/gigaspeech/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py deleted file mode 100755 index 9e0df0989..000000000 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_gigaspeech(): - in_out_dir = Path("data/fbank") - # number of workers in dataloader - num_workers = 20 - - # number of seconds in a batch - batch_duration = 1000 - - subsets = ("L", "M", "S", "XS", "DEV", "TEST") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - - logging.info(f"device: {device}") - - for partition in subsets: - cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Computing features") - - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - overwrite=True, - ) - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - logging.info(f"Saved to {cuts_path}") - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_gigaspeech() - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py deleted file mode 100755 index 51cd59078..000000000 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
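
[Editor's note] The `compute_fbank_gigaspeech.py` removed just above extracts fbank features in dynamically sized batches on GPU when available. A hedged sketch for a single subset, using only the Lhotse calls that appear in the deleted script (paths, worker count, and batch duration mirror its defaults but are illustrative):

```python
# Sketch: fbank extraction for one GigaSpeech subset, mirroring the removed
# compute_fbank_gigaspeech.py.
from pathlib import Path

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

def extract_one_subset(partition: str = "DEV", in_out_dir: Path = Path("data/fbank")) -> None:
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
    cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"

    cut_set = CutSet.from_file(raw_cuts_path)
    # Features are computed in batches of roughly 1000 seconds of audio and
    # stored next to the manifests.
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
        num_workers=20,
        batch_duration=1000,
        overwrite=True,
    )
    # Long recordings are trimmed down to supervision segments before saving.
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, min_duration=None)
    cut_set.to_file(cuts_path)
```
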
- -import argparse -import logging -from datetime import datetime -from pathlib import Path - -import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--num-splits", - type=int, - required=True, - help="The number of splits of the XL subset", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="Process pieces starting from this number (inclusive).", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop processing pieces until this number (exclusive).", - ) - return parser - - -def compute_fbank_gigaspeech_splits(args): - num_splits = args.num_splits - output_dir = f"data/fbank/XL_split" - output_dir = Path(output_dir) - assert output_dir.exists(), f"{output_dir} does not exist!" - - num_digits = 8 # num_digits is fixed by lhotse split-lazy - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - for i in range(start, stop): - idx = f"{i}".zfill(num_digits) - logging.info(f"Processing {idx}/{num_splits}") - - cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Computing features") - - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/gigaspeech_feats_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - overwrite=True, - ) - - logging.info("About to split cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - logging.info(f"Saved to {cuts_path}") - - -def main(): - now = datetime.now() - date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - - log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - log_filename = f"{log_filename}-{date_time}" - - logging.basicConfig( - filename=log_filename, - format=formatter, - level=logging.INFO, - filemode="w", - ) - - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter(formatter)) - logging.getLogger("").addHandler(console) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - compute_fbank_gigaspeech_splits(args) - - -if 
__name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_musan.py b/egs/gigaspeech/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/gigaspeech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/generate_unique_lexicon.py b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py deleted file mode 120000 index c0aea1403..000000000 --- a/egs/gigaspeech/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/prepare_lang.py b/egs/gigaspeech/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/gigaspeech/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/prepare_lang_bpe.py b/egs/gigaspeech/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/gigaspeech/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py deleted file mode 100755 index a31685211..000000000 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
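
[Editor's note] The removed `compute_fbank_gigaspeech_splits.py` names the XL pieces with zero-padded indices of width 8 (fixed by `lhotse split-lazy`) and clamps the `--start`/`--stop` range against `--num-splits`. A small stdlib-only sketch of that bookkeeping; the helper name is hypothetical:

```python
# Sketch of how --start/--stop map to per-split manifest names in the removed
# compute_fbank_gigaspeech_splits.py.
from pathlib import Path
from typing import List

NUM_DIGITS = 8  # fixed by `lhotse split-lazy`

def split_cut_paths(output_dir: Path, num_splits: int, start: int = 0, stop: int = -1) -> List[Path]:
    if stop < start:  # --stop defaults to -1, meaning "process until the end"
        stop = num_splits
    stop = min(stop, num_splits)
    return [
        output_dir / f"gigaspeech_cuts_XL.{str(i).zfill(NUM_DIGITS)}.jsonl.gz"
        for i in range(start, stop)
    ]

# Three paths ending in ...XL.00000000.jsonl.gz, ...00000001..., ...00000002...
print(split_cut_paths(Path("data/fbank/XL_split"), num_splits=2000, start=0, stop=3))
```
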
- -import argparse -import logging -import re -from pathlib import Path - -from lhotse import CutSet, SupervisionSegment -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import str2bool - -# Similar text filtering and normalization procedure as in: -# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Whether to use speed perturbation.", - ) - - return parser.parse_args() - - -def normalize_text( - utt: str, - punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), - whitespace_pattern=re.compile(r"\s\s+"), -) -> str: - return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) - - -def has_no_oov( - sup: SupervisionSegment, - oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"), -) -> bool: - return oov_pattern.search(sup.text) is None - - -def preprocess_giga_speech(args): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - output_dir.mkdir(exist_ok=True) - - dataset_parts = ( - "DEV", - "TEST", - "XL", - "L", - "M", - "S", - "XS", - ) - - logging.info("Loading manifest (may take 4 minutes)") - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix="gigaspeech", - suffix="jsonl.gz", - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - # Note this step makes the recipe different than LibriSpeech: - # We must filter out some utterances and remove punctuation - # to be consistent with Kaldi. - logging.info("Filtering OOV utterances from supervisions") - m["supervisions"] = m["supervisions"].filter(has_no_oov) - logging.info(f"Normalizing text in {partition}") - for sup in m["supervisions"]: - sup.text = normalize_text(sup.text) - - # Create long-recording cut manifests. - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - # Run data augmentation that needs to be done in the - # time domain. 
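
[Editor's note] The text clean-up in the removed `preprocess_gigaspeech.py` strips GigaSpeech punctuation tags and collapses whitespace, and drops any supervision containing a garbage tag. A quick stdlib demonstration using the same regular expressions; the sample transcript is made up:

```python
import re

punct_pattern = re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>")
whitespace_pattern = re.compile(r"\s\s+")
oov_pattern = re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>")

def normalize_text(utt: str) -> str:
    # Remove punctuation tags, then collapse runs of whitespace.
    return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))

def has_no_oov(text: str) -> bool:
    # Supervisions containing <SIL>/<MUSIC>/<NOISE>/<OTHER> are filtered out.
    return oov_pattern.search(text) is None

sample = "HELLO <COMMA> HOW ARE YOU <QUESTIONMARK>"
print(repr(normalize_text(sample)))   # 'HELLO HOW ARE YOU '
print(has_no_oov("SOME <NOISE> SEGMENT"))  # False -> would be dropped
```
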
- if partition not in ["DEV", "TEST"]: - if args.perturb_speed: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - preprocess_giga_speech(args) - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/local/train_bpe_model.py b/egs/gigaspeech/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/gigaspeech/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh deleted file mode 100755 index 219197e13..000000000 --- a/egs/gigaspeech/ASR/prepare.sh +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=0 -stop_stage=100 - -# Split XL subset to a number of pieces (about 2000) -# This is to avoid OOM during feature extraction. -num_per_split=50 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/GigaSpeech -# You can find audio, dict, GigaSpeech.json inside it. -# You can apply for the download credentials by following -# https://github.com/SpeechColab/GigaSpeech#download -# -# - $dl_dir/lm -# This directory contains the language model downloaded from -# https://huggingface.co/wgb14/gigaspeech_lm -# -# - 3gram_pruned_1e7.arpa.gz -# - 4gram.arpa.gz -# - lexicon.txt -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" - # We assume that you have installed the git-lfs, if not, you could install it - # using: `sudo apt-get install git-lfs && git-lfs install` - [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm - git clone https://huggingface.co/wgb14/gigaspeech_lm $dl_dir/lm - gunzip -c $dl_dir/lm/3gram_pruned_1e7.arpa.gz > $dl_dir/lm/3gram_pruned_1e7.arpa - gunzip -c $dl_dir/lm/4gram.arpa.gz > $dl_dir/lm/4gram.arpa -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - [ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech - - # If you have pre-downloaded it to /path/to/GigaSpeech, - # you can create a symlink - # - # ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech - # - if [ ! 
-d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then - # Check credentials. - if [ ! -f $dl_dir/password ]; then - echo -n "$0: Please apply for the download credentials by following" - echo -n "https://github.com/SpeechColab/GigaSpeech#download" - echo " and save it to $dl_dir/password." - exit 1; - fi - PASSWORD=`cat $dl_dir/password 2>/dev/null` - if [ -z "$PASSWORD" ]; then - echo "$0: Error, $dl_dir/password is empty." - exit 1; - fi - PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1` - if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then - echo "$0: Error, invalid $dl_dir/password." - exit 1; - fi - # Download XL, DEV and TEST sets by default. - lhotse download gigaspeech --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ - --subset DEV \ - --subset TEST \ - --host tsinghua \ - $dl_dir/password $dl_dir/GigaSpeech - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare GigaSpeech manifest (may take 15 minutes)" - # We assume that you have downloaded the GigaSpeech corpus - # to $dl_dir/GigaSpeech - mkdir -p data/manifests - lhotse prepare gigaspeech --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ - --subset DEV \ - --subset TEST \ - -j $nj \ - $dl_dir/GigaSpeech data/manifests -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "State 3: Preprocess GigaSpeech manifest" - if [ ! -f data/fbank/.preprocess_complete ]; then - python3 ./local/preprocess_gigaspeech.py - touch data/fbank/.preprocess_complete - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech." - python3 ./local/compute_fbank_gigaspeech.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split XL subset into pieces (may take 30 minutes)" - split_dir=data/fbank/XL_split - if [ ! -f $split_dir/.split_completed ]; then - lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $num_per_split - touch $split_dir/.split_completed - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compute features for XL" - num_splits=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l) - python3 ./local/compute_fbank_gigaspeech_splits.py \ - --num-workers 20 \ - --batch-duration 600 \ - --num-splits $num_splits -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Combine features for XL (may take 3 hours)" - if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then - pieces=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL.*.jsonl.gz") - lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compute fbank for musan" - mkdir -p data/fbank - ./local/compute_fbank_musan.py -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare transcript_words.txt and words.txt" - lang_dir=data/lang_phone - mkdir -p $lang_dir - if [ ! 
-f $lang_dir/transcript_words.txt ]; then - gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ - | jq '.text' \ - | sed 's/"//g' \ - > $lang_dir/transcript_words.txt - - # Delete utterances with garbage meta tags - garbage_utterance_tags=" " - for tag in $garbage_utterance_tags; do - sed -i "/${tag}/d" $lang_dir/transcript_words.txt - done - - # Delete punctuations in utterances - punctuation_tags=" " - for tag in $punctuation_tags; do - sed -i "s/${tag}//g" $lang_dir/transcript_words.txt - done - - # Ensure space only appears once - sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt - fi - - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo ''; echo ''; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare phone based lang" - lang_dir=data/lang_phone - mkdir -p $lang_dir - - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/{words.txt,transcript_words.txt} $lang_dir - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare bigram P" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/3gram_pruned_1e7.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! 
-f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/4gram.arpa > data/lm/G_4_gram.fst.txt - fi -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - done -fi diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index 40339365c..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,412 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class GigaSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
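
[Editor's note] Once `prepare.sh` has produced the fbank manifests, the datamodule defined below is driven from a training entry point the same way the removed `conformer_ctc/train.py` does: register its CLI options, build cut sets, then dataloaders. A hedged sketch of that wiring (it assumes it is run from the recipe directory so the local `asr_datamodule` module is importable, and that the usual `--manifest-dir`/`--max-duration` style options are supplied on the command line):

```python
# Sketch: wiring GigaSpeechAsrDataModule into a training script, mirroring the
# run() function of the removed train.py.
import argparse

from asr_datamodule import GigaSpeechAsrDataModule  # recipe-local module

def build_dataloaders():
    parser = argparse.ArgumentParser()
    GigaSpeechAsrDataModule.add_arguments(parser)  # adds --manifest-dir, --max-duration, ...
    args = parser.parse_args()

    gigaspeech = GigaSpeechAsrDataModule(args)

    train_cuts = gigaspeech.train_cuts()           # lazy CutSet for the selected --subset
    train_dl = gigaspeech.train_dataloaders(train_cuts)

    valid_cuts = gigaspeech.dev_cuts()
    valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
    return train_dl, valid_dl
```
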
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. 
", - ) - - # GigaSpeech specific arguments - group.add_argument( - "--subset", - type=str, - default="XL", - help="Select the GigaSpeech subset (XS|S|M|L|XL)", - ) - group.add_argument( - "--small-dev", - type=str2bool, - default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train_{self.args.subset} cuts") - path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - 
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py deleted file mode 100755 index 76306fc4c..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corp. (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./pruned_transducer_stateless7/compute_ppl.py \ - --ngram-lm-path ./download/lm/3gram_pruned_1e7.arpa - -""" - - -import argparse -import logging -import math -from typing import Dict, List, Optional, Tuple - -import kenlm -import torch -from asr_datamodule import GigaSpeechAsrDataModule - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--ngram-lm-path", - type=str, - default="download/lm/3gram_pruned_1e7.arpa", - help="The lang dir containing word table and LG graph", - ) - - return parser - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: kenlm.Model, -) -> Dict[str, float]: - """ - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - A ngram lm of kenlm.Model object. - Returns: - Return the perplexity of the giving dataset. 
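
[Editor's note] The perplexity computed by the removed `compute_ppl.py` treats kenlm's sentence scores as base-10 log probabilities (hence the `math.pow(10.0, ...)` below) and counts one extra token per utterance for the end-of-sentence symbol. A tiny sketch of that arithmetic with made-up scores, so it runs without kenlm installed:

```python
# Sketch of the perplexity arithmetic in compute_ppl.py; the log10 sentence
# scores are invented stand-ins for kenlm.Model.score() output.
import math

scored = [
    ("i like speech recognition", -7.2),  # (text, hypothetical log10 probability)
    ("the cat sat on the mat", -6.1),
]

sum_neg_log10 = 0.0
num_tokens = 0
for text, log10_prob in scored:
    num_tokens += len(text.split()) + 1   # +1 accounts for the end-of-sentence token
    sum_neg_log10 += -log10_prob

ppl = math.pow(10.0, sum_neg_log10 / num_tokens)
print(f"PPL: {ppl:.2f}")
```
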
- """ - sum_score_log = 0 - sum_n = 0 - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - for text in texts: - sum_n += len(text.split()) + 1 - sum_score_log += -1 * model.score(text) - - ppl = math.pow(10.0, sum_score_log / sum_n) - - return ppl - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - logging.info("About to load ngram LM") - model = kenlm.Model(args.ngram_lm_path) - - gigaspeech = GigaSpeechAsrDataModule(args) - - dev_cuts = gigaspeech.dev_cuts() - test_cuts = gigaspeech.test_cuts() - - dev_dl = gigaspeech.test_dataloaders(dev_cuts) - test_dl = gigaspeech.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - ppl = decode_dataset( - dl=test_dl, - model=model, - ) - logging.info(f"{test_set} PPL: {ppl}") - - logging.info("Done!") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index f1efebcb9..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,629 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from gigaspeech_scoring import asr_text_post_processing -from train import get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=8, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
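
[Editor's note] As the docstring above describes, `decode_one_batch` keys its results by the decoding configuration so that the recognition and error-stat files written later stay separate per setting. A small sketch of that key construction, restating the branches that follow with the parser defaults as example values:

```python
# Sketch of how decode_one_batch names its result sets.
def result_key(decoding_method: str, beam_size: int = 4, beam: float = 4.0,
               max_contexts: int = 4, max_states: int = 8) -> str:
    if decoding_method == "greedy_search":
        return "greedy_search"
    if decoding_method == "fast_beam_search":
        return f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
    # beam_search / modified_beam_search
    return f"beam_size_{beam_size}"

print(result_key("fast_beam_search"))  # beam_4.0_max_contexts_4_max_states_8
print(result_key("modified_beam_search"))  # beam_size_4
```
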
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = post_processing(results) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - 
sp.load(params.bpe_model) - - # <blk> and <unk> are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results.
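The `--use-averaged-model` branch above relies on each checkpoint storing a running average of all parameters since the start of training, so the average over a window (start, end] can be recovered from just the two endpoint checkpoints. A hedged toy illustration, with a scalar "parameter" and made-up numbers:

params_per_batch = [2.0, 4.0, 6.0, 8.0]  # parameter value after each training batch

def avg_at(n):  # what the stored running average holds after n batches
    return sum(params_per_batch[:n]) / n

batch_start, batch_end = 2, 4
window_avg = (
    avg_at(batch_end) * batch_end - avg_at(batch_start) * batch_start
) / (batch_end - batch_start)
print(window_avg)  # 7.0 == (6.0 + 8.0) / 2, i.e. the average over batches 3..4 only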
- args.return_cuts = True - gigaspeech = GigaSpeechAsrDataModule(args) - - dev_cuts = gigaspeech.dev_cuts() - test_cuts = gigaspeech.test_cuts() - - dev_dl = gigaspeech.test_dataloaders(dev_cuts) - test_dl = gigaspeech.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index 4a44f7bcb..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
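The header comment above says the script merges several checkpoints into one by model averaging. As a rough, plain-PyTorch sketch of the arithmetic only (not the icefall implementation, which works on checkpoint files and a target device):

from typing import Dict, List

import torch

def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Element-wise mean of the parameter tensors of several checkpoints.
    avg = {k: v.detach().clone().float() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg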
-""" -Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the <blk> token and the vocab size - # <blk> is defined in local/train_bpe_model.py - params.blank_id = token_table["<blk>"] - params.unk_id = token_table["<unk>"] - params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptable.
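The torch.jit.ignore call on the next line keeps torch.jit.script() from trying to compile forward(). A self-contained toy of that pattern; the Toy module and its double() method are made-up stand-ins, not the recipe's model:

import torch
import torch.nn as nn

class Toy(nn.Module):
    @torch.jit.ignore
    def forward(self, x, y):  # `y` stands in for a non-scriptable argument
        return x

    @torch.jit.export
    def double(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

scripted = torch.jit.script(Toy())          # forward() is left as plain Python
print(scripted.double(torch.ones(3)))       # tensor([2., 2., 2.])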
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py deleted file mode 120000 index a6a4d12b1..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py deleted file mode 100755 index a7772b62f..000000000 --- 
a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,989 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless2/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --max-duration 120 - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 0 \ - --use_fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ - --max-duration 200 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. 
- If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. 
`model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 500, - "reset_interval": 2000, - "valid_interval": 20000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 20000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
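Restating the schedule described in the comment above as a small function (it mirrors the expression computed on the next lines; warmup itself is batch_idx_train / model_warm_step, as passed in train_one_epoch):

def pruned_loss_scale(warmup: float) -> float:
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

# With the default simple_loss_scale of 0.5:
#   warmup = 0.5 -> loss = 0.5 * simple_loss                      (pruned loss off)
#   warmup = 1.5 -> loss = 0.5 * simple_loss + 0.1 * pruned_loss  (kept small)
#   warmup = 2.5 -> loss = 0.5 * simple_loss + 1.0 * pruned_loss  (fully on)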
- pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - model_avg=model_avg, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
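Because the loss is an unnormalized sum (see the note above) and both the summed loss and the summed frame count are accumulated with the same decay factor 1 - 1/reset_interval in the tot_loss update, the logged ratio tot_loss["loss"] / tot_loss["frames"] stays a per-frame average over roughly the last reset_interval batches. A toy illustration with made-up batch statistics:

reset_interval = 2000                    # default from get_params()
decay = 1.0 - 1.0 / reset_interval
tot_loss_sum, tot_frames = 0.0, 0.0
for _ in range(10_000):                  # pretend every batch sums to loss 300 over 1500 frames
    tot_loss_sum = tot_loss_sum * decay + 300.0
    tot_frames = tot_frames * decay + 1500.0
print(tot_loss_sum / tot_frames)         # ~0.2, the per-frame loss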
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 30: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - gigaspeech = GigaSpeechAsrDataModule(args) - - train_cuts = gigaspeech.train_cuts() - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = gigaspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = gigaspeech.dev_cuts() - valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if
params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/shared b/egs/gigaspeech/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/gigaspeech/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index 0501461cd..000000000 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
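A hedged usage sketch of the GigaSpeechAsrDataModule defined in this file, mirroring how the recipe scripts above drive it; it assumes the recipe directory is on sys.path and that the fbank manifests under data/fbank have already been prepared, and the "XS" subset is only an example:

import argparse

from asr_datamodule import GigaSpeechAsrDataModule

parser = argparse.ArgumentParser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--subset", "XS", "--max-duration", "200"])

gigaspeech = GigaSpeechAsrDataModule(args)
train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())
valid_dl = gigaspeech.valid_dataloaders(gigaspeech.dev_cuts())
test_dl = gigaspeech.test_dataloaders(gigaspeech.test_cuts())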
- - -import argparse -import glob -import inspect -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import lhotse -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class GigaSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=100, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - # GigaSpeech specific arguments - group.add_argument( - "--subset", - type=str, - default="XL", - help="Select the GigaSpeech subset (XS|S|M|L|XL)", - ) - group.add_argument( - "--small-dev", - type=str2bool, - default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
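The comment above refers to probing the installed Lhotse version indirectly, by reading the default value of SpecAugment's num_frame_masks argument. A self-contained sketch of that inspection trick on the next lines (spec_augment_stub is a stand-in here, not Lhotse's class):

import inspect

def spec_augment_stub(num_frame_masks: int = 1):
    pass

default = inspect.signature(spec_augment_stub).parameters["num_frame_masks"].default
num_frame_masks = 2 if default == 1 else 10
print(default, num_frame_masks)  # 1 2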
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train {self.args.subset} cuts") - if self.args.subset == "XL": - filenames = glob.glob( - f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" - ) - pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" - ) - - cuts_train = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) - else: - path = ( - self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" - ) - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) diff --git a/egs/gigaspeech/ASR/zipformer/beam_search.py b/egs/gigaspeech/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/gigaspeech/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff 
--git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py deleted file mode 100755 index 651f20cb6..000000000 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1,847 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Liyong Guo, -# Quandong Wang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) ctc-decoding -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --decoding-method ctc-decoding - -(2) 1best -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method 1best - -(3) nbest -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method nbest - -(4) nbest-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method nbest-rescoring - -(5) whole-lattice-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method whole-lattice-rescoring -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="ctc-decoding", - help="""Decoding method. - Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (2) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--hlg-scale", - type=float, - default=0.6, - help="""The scale to be applied to `hlg.scores`. - """, - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The n-gram LM dir. 
- It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. - - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. - - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. - - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
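To make the return format described in the `decode_one_batch` docstring above concrete, here is a hypothetical value for a batch of two utterances when LM rescoring is used (the keys and words are invented for illustration):

```python
# Hypothetical decode_one_batch() result with whole-lattice rescoring:
# one entry per LM scale, each a list of word-level hypotheses whose
# length equals the batch size.
hyps_dict = {
    "lm_scale_0.7": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
    "lm_scale_0.8": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
}
assert all(len(hyps) == 2 for hyps in hyps_dict.values())  # batch size 2
```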
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.decoding_method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. 
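A toy, runnable illustration of how `supervision_segments` is assembled above; the subsampling factor of 4 is only an assumption for this example (the script takes it from `params.subsampling_factor`):

```python
import torch

# Two utterances: 1000 and 640 feature frames, both starting at frame 0.
supervisions = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([1000, 640]),
}
subsampling_factor = 4  # assumed value for this example

supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        torch.div(supervisions["start_frame"], subsampling_factor, rounding_mode="floor"),
        torch.div(supervisions["num_frames"], subsampling_factor, rounding_mode="floor"),
    ),
    1,
).to(torch.int32)

print(supervision_segments.tolist())  # [[0, 0, 250], [1, 0, 160]]
```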
- best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.decoding_method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.decoding_method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.decoding_method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - else: - assert False, f"Unsupported decoding method: {params.decoding_method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - word_table: - It is the word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
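For reference, the LM-scale grid swept by the two rescoring branches above is simply 0.1 through 2.0 in steps of 0.1; a compact equivalent:

```python
# Equivalent to the three concatenated lists in the code above.
lm_scale_list = [round(0.1 * i, 1) for i in range(1, 21)]
assert len(lm_scale_list) == 20
assert lm_scale_list[0] == 0.1 and lm_scale_list[-1] == 2.0
```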
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - assert params.decoding_method in ( - "ctc-decoding", - "1best", - "nbest", - "nbest-rescoring", - "whole-lattice-rescoring", - "nbest-oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
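The per-key lists assembled above and consumed by `save_results()` are lists of `(cut_id, ref_words, hyp_words)` tuples; a made-up two-entry example (the cut ids are invented):

```python
# Hypothetical results list for one decoding setting.
results = [
    ("POD0000000001_S0000001", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
    ("POD0000000001_S0000002", ["GOOD", "MORNING"], ["GOOD", "MORNING"]),
]
# store_transcripts() writes these ref/hyp pairs to disk and
# write_error_stats() computes the WER for this setting from the same list.
```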
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - params.vocab_size = num_classes - # and are defined in local/train_bpe_model.py - params.blank_id = 0 - - if params.decoding_method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - HLG.scores *= params.hlg_scale - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.decoding_method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.decoding_method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. 
- G.lm_scores = G.scores.clone() - else: - G = None - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
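The checkpoint-averaging branches above rely on icefall's helpers; the essential idea, a parameter-wise mean over several checkpoints, can be sketched as follows. This is illustrative only and ignores the running-average bookkeeping that `average_checkpoints_with_averaged_model` performs.

```python
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
    """Parameter-wise mean of several checkpoints (illustrative sketch).

    Assumes each checkpoint file is a dict holding a "model" state_dict;
    non-floating-point entries are simply copied from the last file.
    """
    avg: Dict[str, torch.Tensor] = {}
    n = len(filenames)
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        for k, v in state.items():
            if torch.is_floating_point(v):
                avg[k] = avg.get(k, torch.zeros_like(v)) + v / n
            else:
                avg[k] = v
    return avg
```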
- args.return_cuts = True - gigaspeech = GigaSpeechAsrDataModule(args) - - test_clean_cuts = gigaspeech.test_clean_cuts() - test_other_cuts = gigaspeech.test_other_cuts() - - test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts) - test_other_dl = gigaspeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/zipformer/decode.py b/egs/gigaspeech/ASR/zipformer/decode.py deleted file mode 100755 index 3a0c71484..000000000 --- a/egs/gigaspeech/ASR/zipformer/decode.py +++ /dev/null @@ -1,1065 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece 
as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from gigaspeech_scoring import asr_text_post_processing -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. 
- """, - ) - add_model_arguments(parser) - - return parser - - -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = post_processing(results) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." 
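As a worked example of the result-file suffix assembled in the lines that follow: with this script's defaults (epoch 30, avg 15, offline greedy search, averaged model) it becomes:

```python
epoch, avg, context_size, max_sym_per_frame = 30, 15, 2, 1
suffix = f"epoch-{epoch}-avg-{avg}"
# "greedy_search" matches neither the fast_beam_search nor the *beam_search*
# string checks, so only context size and max-sym-per-frame are appended:
suffix += f"-context-{context_size}-max-sym-per-frame-{max_sym_per_frame}"
suffix += "-use-averaged-model"
print(suffix)
# epoch-30-avg-15-context-2-max-sym-per-frame-1-use-averaged-model
```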
- assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} 
(excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append(line.strip()) - context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
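The `--context-file` consumed above is a plain text file with one biasing word or phrase per line; a minimal sketch of preparing and reading such a file (the file name and contents are invented):

```python
from pathlib import Path

# Create a toy biasing list; in the recipe the path comes from --context-file.
path = Path("context_biasing.txt")
path.write_text("K2 FSA\nZIPFORMER\nGIGASPEECH\n")

contexts = [line.strip() for line in path.read_text().splitlines() if line.strip()]
print(contexts)  # ['K2 FSA', 'ZIPFORMER', 'GIGASPEECH']

# The script then BPE-encodes these phrases and builds the biasing graph,
# as in the code above:
#   context_graph = ContextGraph(context_score)
#   context_graph.build(sp.encode(contexts))
```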
- args.return_cuts = True - gigaspeech = GigaSpeechAsrDataModule(args) - - dev_cuts = gigaspeech.dev_cuts() - test_cuts = gigaspeech.test_cuts() - - dev_dl = gigaspeech.test_dataloaders(dev_cuts) - test_dl = gigaspeech.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/zipformer/decode_stream.py b/egs/gigaspeech/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/gigaspeech/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/decoder.py b/egs/gigaspeech/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/gigaspeech/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/encoder_interface.py b/egs/gigaspeech/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/gigaspeech/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py deleted file mode 120000 index f9d756352..000000000 --- a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx.py b/egs/gigaspeech/ASR/zipformer/export-onnx.py deleted file mode 100755 index 0f78cfe5b..000000000 --- a/egs/gigaspeech/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1,620 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/gigaspeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. 
Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. 
- """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming zipformer2", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range 
from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/export.py b/egs/gigaspeech/ASR/zipformer/export.py deleted file mode 100755 index e45c96b57..000000000 --- a/egs/gigaspeech/ASR/zipformer/export.py +++ /dev/null @@ -1,522 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding 
multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is a example for gigaspeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. 
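For illustration only (this sketch is not part of the original script), loading such a `pretrained.pt` might look roughly like the following, assuming `get_model`/`get_params` come from `zipformer/train.py` and that `params.vocab_size`, `params.blank_id`, and the model arguments are filled in the same way as in the export scripts (e.g., from data/lang_bpe_500/tokens.txt):

    from train import get_model, get_params       # zipformer/train.py
    from icefall.checkpoint import load_checkpoint

    params = get_params()
    # Fill in params.vocab_size, params.blank_id and any model arguments here,
    # mirroring what export.py does with the tokens.txt file.
    model = get_model(params)

    # pretrained.pt stores {"model": model.state_dict()}, so load_checkpoint()
    # (or torch.load() followed by load_state_dict()) restores the weights.
    load_checkpoint("zipformer/exp/pretrained.pt", model)
    model.eval()
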
- -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/gigaspeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/gigaspeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -- non-streaming model: -https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. 
- """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
- """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py deleted file mode 120000 index a6a4d12b1..000000000 --- a/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 120000 index 9a8da5844..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/joiner.py b/egs/gigaspeech/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/gigaspeech/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/model.py b/egs/gigaspeech/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/gigaspeech/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/my_profile.py b/egs/gigaspeech/ASR/zipformer/my_profile.py deleted file mode 120000 index 3a90b2628..000000000 --- a/egs/gigaspeech/ASR/zipformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_check.py b/egs/gigaspeech/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- 
a/egs/gigaspeech/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_decode.py b/egs/gigaspeech/ASR/zipformer/onnx_decode.py deleted file mode 120000 index 0573b88c5..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py deleted file mode 120000 index a3183ebf6..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py deleted file mode 120000 index a4fd76ac2..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py deleted file mode 120000 index f805e3761..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py deleted file mode 120000 index 8343d5079..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/optim.py b/egs/gigaspeech/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/gigaspeech/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained.py b/egs/gigaspeech/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/gigaspeech/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py deleted file mode 120000 index c2f6f6fc3..000000000 --- a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling.py 
b/egs/gigaspeech/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/gigaspeech/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling_converter.py b/egs/gigaspeech/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/gigaspeech/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py b/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py deleted file mode 100755 index cb3fd0dc7..000000000 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,859 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import GigaSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
- """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). 
- state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - gigaspeech = GigaSpeechAsrDataModule(args) - - dev_cuts = gigaspeech.dev_cuts() - test_cuts = gigaspeech.test_cuts() - - test_sets = ["dev", "test"] - test_cuts = [dev_cuts, test_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/zipformer/subsampling.py b/egs/gigaspeech/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/gigaspeech/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_scaling.py b/egs/gigaspeech/ASR/zipformer/test_scaling.py deleted file mode 120000 index 715798436..000000000 --- a/egs/gigaspeech/ASR/zipformer/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_subsampling.py b/egs/gigaspeech/ASR/zipformer/test_subsampling.py deleted file mode 120000 index bf0ee3d11..000000000 --- a/egs/gigaspeech/ASR/zipformer/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py deleted file mode 100755 index 4c122effe..000000000 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ /dev/null @@ -1,1364 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
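# Worked example: with the values from the usage example at the top of this
# file (--max-duration 1000, --world-size 8) and the default ref_duration of
# 600, each real batch counts as 1000 * 8 / 600 ~= 13.3 reference batches,
# so after 1500 training batches the adjusted count is about 20000.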
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=1, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--scan-for-oom-batches", - type=str2bool, - default=False, - help=""" - Whether to scan for oom batches before training, this is helpful for - finding the suitable max_duration, you only need to run it once. - Caution: a little time consuming. - """, - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 500, - "reset_interval": 2000, - "valid_interval": 20000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = 
get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_utt(c: Cut): - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - return T > 0 - - gigaspeech = GigaSpeechAsrDataModule(args) - - train_cuts = gigaspeech.train_cuts() - train_cuts = train_cuts.filter(remove_short_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = gigaspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = gigaspeech.dev_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and params.scan_for_oom_batches: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading 
grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/zipformer/zipformer.py b/egs/gigaspeech/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/gigaspeech/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/RESULTS.md b/egs/gigaspeech/KWS/RESULTS.md deleted file mode 100644 index 992240e14..000000000 --- a/egs/gigaspeech/KWS/RESULTS.md +++ /dev/null @@ -1,49 +0,0 @@ -# Results - -## zipformer transducer model - -This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. - -The modeling units are 500 BPEs trained on gigaspeech transcripts. - -The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test set of gigaspeech (has 40 hours audios). - -We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. - -The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz). - -Here is the results of a small test set which has 20 commands, we list the results of every commands, for -each metric there are two columns, one for the original model trained on gigaspeech XL subset, the other -for the finetune model finetuned on commands dataset. 
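To read the tables below: recall corresponds to 1 minus the FN fraction (e.g. 43/307 missed keywords gives about 86% recall), and the false-alarm rate is the FP count divided by the duration of the negative set, so 1 false positive over the 40-hour gigaspeech test set is 1/40 = 0.025 per hour.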
- -Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours | --- | -- | -- | -- | --| -- | -- | -- | -- -  | original | finetune | original | finetune | original | finetune | original | finetune -All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6 -Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225 -Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025 -Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05 -Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0 -Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0 -Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1 -Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05 -Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0 -Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0 -Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025 -Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025 -Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0 -Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0 -Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 -Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 -Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 -Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0 -Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0 -Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075 -Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025 - -This is the result of large test set, it has more than 200 commands, too many to list the details of each commands, so only an overall result here. - -Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours --- | -- | -- | -- | -- | -- | -- | -- | -- -  | original | finetune | original | finetune | original | finetune | original | finetune -All | 622/3994 | 79/ 3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3 diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh deleted file mode 100755 index 0b098190d..000000000 --- a/egs/gigaspeech/KWS/prepare.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=0 -stop_stage=100 - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Prepare gigaspeech dataset." - mkdir -p data/fbank - if [ ! -e data/fbank/.gigaspeech.done ]; then - pushd ../ASR - ./prepare.sh --stage 0 --stop-stage 9 - ./prepare.sh --stage 11 --stop-stage 11 - popd - pushd data/fbank - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) . 
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) . - ln -svf $(realpath ../ASR/data/fbank/XL_split) . - ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/musan_feats) . - popd - pushd data - ln -svf $(realpath ../ASR/data/lang_bpe_500) . - popd - touch data/fbank/.gigaspeech.done - else - log "Gigaspeech dataset already exists, skipping." - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare open commands dataset." - mkdir -p data/fbank - if [ ! -e data/fbank/.fluent_speech_commands.done ]; then - pushd data - git clone https://github.com/pkufool/open-commands.git - ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt - ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt - pushd open-commands - ./script/prepare.sh --stage 2 --stop-stage 2 - ./script/prepare.sh --stage 6 --stop-stage 6 - popd - popd - pushd data/fbank - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) . - popd - touch data/fbank/.fluent_speech_commands.done - else - log "Fluent speech commands dataset already exists, skipping." - fi -fi diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh deleted file mode 100755 index 303abd718..000000000 --- a/egs/gigaspeech/KWS/run.sh +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -export CUDA_VISIBLE_DEVICES="0,1,2,3" -export PYTHONPATH=../../../:$PYTHONPATH - -stage=0 -stop_stage=100 - -. shared/parse_options.sh || exit 1 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Train a model." - if [ ! -e data/fbank/.gigaspeech.done ]; then - log "You need to run the prepare.sh first." 
- exit -1 - fi - - python ./zipformer/train.py \ - --world-size 4 \ - --exp-dir zipformer/exp \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --num-epochs 12 \ - --lr-epochs 1.5 \ - --use-fp16 1 \ - --start-epoch 1 \ - --subset XL \ - --bpe-model data/lang_bpe_500/bpe.model \ - --causal 1 \ - --max-duration 1000 -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Decode the model." - - export CUDA_VISIBLE_DEVICES="0" - for t in small large; do - python ./zipformer/decode.py \ - --epoch 12 \ - --avg 2 \ - --exp-dir ./zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --test-set $t \ - --keywords-score 1.0 \ - --keywords-threshold 0.35 \ - --keywords-file ./data/commands_${t}.txt \ - --max-duration 3000 - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Export the model." - - python ./zipformer/export.py \ - --epoch 12 \ - --avg 2 \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 - - python ./zipformer/export-onnx-streaming.py \ - --exp-dir zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 12 \ - --avg 2 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 2: Finetune the model" - - # The following configuration of lr schedule should work well - # You may also tune the following parameters to adjust learning rate schedule - base_lr=0.0005 - lr_epochs=100 - lr_batches=100000 - - # We recommend to start from an averaged model - finetune_ckpt=zipformer/exp/pretrained.pt - - ./zipformer/finetune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --exp-dir zipformer/exp_finetune \ - --bpe-model data/lang_bpe_500/bpe.model \ - --use-fp16 1 \ - --use-mux 1 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 \ - --base-lr $base_lr \ - --lr-epochs $lr_epochs \ - --lr-batches $lr_batches \ - --finetune-ckpt $finetune_ckpt \ - --max-duration 1500 -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 1: Decode the finetuned model." 
- export CUDA_VISIBLE_DEVICES="0" - for t in small large; do - python ./zipformer/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./zipformer/exp_finetune \ - --bpe-model data/lang_bpe_500/bpe.model \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --test-set $t \ - --keywords-score 1.0 \ - --keywords-threshold 0.35 \ - --keywords-file ./data/commands_${t}.txt \ - --max-duration 3000 - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 2: Export the finetuned model." - - python ./zipformer/export.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./zipformer/exp_finetune \ - --tokens data/lang_bpe_500/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 - - python ./zipformer/export-onnx-streaming.py \ - --exp-dir zipformer/exp_finetune \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 -fi diff --git a/egs/gigaspeech/KWS/shared b/egs/gigaspeech/KWS/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/gigaspeech/KWS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py deleted file mode 100644 index ccc602404..000000000 --- a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import glob -import inspect -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import lhotse -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class GigaSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - # GigaSpeech specific arguments - group.add_argument( - "--subset", - type=str, - default="XL", - help="Select the GigaSpeech subset (XS|S|M|L|XL)", - ) - group.add_argument( - "--small-dev", - type=str2bool, - default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
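All of these options are registered on an argument group, so any training or decoding script that calls GigaSpeechAsrDataModule.add_arguments(parser) picks them up automatically. The fragment below is only a simplified illustration of that pattern, with a local str2bool stand-in for icefall.utils.str2bool; it is not the recipe's actual wiring.

    # Sketch of how the data-module options above are typically consumed.
    import argparse

    def str2bool(v: str) -> bool:
        # Simplified stand-in for icefall.utils.str2bool.
        return v.lower() in ("yes", "true", "t", "1")

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="ASR data related options")
    group.add_argument("--max-duration", type=int, default=200)
    group.add_argument("--on-the-fly-feats", type=str2bool, default=False)
    group.add_argument("--subset", type=str, default="XL")

    args = parser.parse_args(["--max-duration", "1000", "--on-the-fly-feats", "true"])
    print(args.max_duration, args.on_the_fly_feats, args.subset)  # 1000 True XL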
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
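The branch above keeps the recipe compatible with old and new Lhotse releases by inspecting the default value of SpecAugment's num_frame_masks argument instead of checking a version string. The standalone sketch below reproduces that pattern with a dummy class, so it runs without Lhotse installed.

    # Compatibility pattern used above: choose a value based on the default
    # argument of another library's constructor. DummySpecAugment stands in
    # for lhotse.dataset.SpecAugment.
    import inspect

    class DummySpecAugment:
        def __init__(self, num_frame_masks: int = 1):
            self.num_frame_masks = num_frame_masks

    default = (
        inspect.signature(DummySpecAugment.__init__).parameters["num_frame_masks"].default
    )
    # Mirror the logic above: if the installed default is 1, use 2 masks,
    # otherwise keep 10.
    num_frame_masks = 2 if default == 1 else 10
    print(f"library default={default}, chosen num_frame_masks={num_frame_masks}")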
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train {self.args.subset} cuts") - if self.args.subset == "XL": - filenames = glob.glob( - f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" - ) - pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" - ) - - cuts_train = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) - else: - path = ( - self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" - ) - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) - - @lru_cache() - def fsc_train_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz" - ) - - @lru_cache() - def fsc_valid_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands valid 
cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def fsc_test_small_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands small test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz" - ) - - @lru_cache() - def fsc_test_large_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands large test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz" - ) diff --git a/egs/gigaspeech/KWS/zipformer/beam_search.py b/egs/gigaspeech/KWS/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/gigaspeech/KWS/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py deleted file mode 100755 index 149b8bed0..000000000 --- a/egs/gigaspeech/KWS/zipformer/decode-asr.py +++ /dev/null @@ -1,1066 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from gigaspeech_scoring import asr_text_post_processing -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - add_model_arguments(parser) - - return parser - - -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
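The post_processing helper above re-joins each reference and hypothesis into a string, runs GigaSpeech's asr_text_post_processing over it, and splits it back into words, so that scoring uses normalized text. A rough sketch of that wrapper follows, with a toy normalizer standing in for gigaspeech_scoring.asr_text_post_processing (the real one applies GigaSpeech-specific rules).

    # Sketch of the post-processing wrapper above. `toy_normalize` is a
    # placeholder for gigaspeech_scoring.asr_text_post_processing.
    from typing import List, Tuple

    def toy_normalize(text: str) -> str:
        # Illustrative only: uppercase and drop a couple of filler tokens.
        words = [w for w in text.upper().split() if w not in ("UH", "UM")]
        return " ".join(words)

    def post_process(results: List[Tuple[str, List[str], List[str]]]):
        new_results = []
        for key, ref, hyp in results:
            new_ref = toy_normalize(" ".join(ref)).split()
            new_hyp = toy_normalize(" ".join(hyp)).split()
            new_results.append((key, new_ref, new_hyp))
        return new_results

    print(post_process([("cut-1", ["uh", "hello", "world"], ["hello", "word"])]))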
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. 
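decode_one_batch above returns a dict whose key encodes the decoding configuration, so results produced under different settings land in separate recogs-*/errs-* files. As a rough illustration, the key for the fast_beam_search family is assembled like this (a sketch that mirrors the string building above, not the script itself):

    # Sketch of the result-key naming used above for fast_beam_search variants.
    def fast_beam_search_key(method: str, beam: float, max_contexts: int,
                             max_states: int, num_paths: int = 200,
                             nbest_scale: float = 0.5,
                             ngram_lm_scale: float = 0.01) -> str:
        key = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
        if "nbest" in method:
            key += f"_num_paths_{num_paths}_nbest_scale_{nbest_scale}"
        if "LG" in method:
            key += f"_ngram_lm_scale_{ngram_lm_scale}"
        return key

    print(fast_beam_search_key("fast_beam_search", 20.0, 8, 64))
    print(fast_beam_search_key("fast_beam_search_nbest_LG", 20.0, 8, 64))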
- """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = post_processing(results) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), 
"chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over 
iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append(line.strip()) - context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - gigaspeech = GigaSpeechAsrDataModule(args) - - test_cuts = gigaspeech.test_cuts() - test_dl = gigaspeech.test_dataloaders(test_cuts) - - test_fsc_cuts = gigaspeech.fsc_test_large_cuts() - test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts) - - test_sets = ["test", "fsc_test"] - test_dls = [test_dl, test_fsc_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py deleted file mode 100755 index 0df2ec356..000000000 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ /dev/null @@ -1,687 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --keywords-file keywords.txt \ - --beam-size 4 -""" - -import argparse -import logging -import math -import os -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict, List, Optional, Set, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import keywords_search -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -@dataclass -class KwMetric: - TP: int = 0 # True positive - FN: int = 0 # False negative - FP: int = 0 # False positive - TN: int = 0 # True negative - FN_list: List[str] = field(default_factory=list) - FP_list: List[str] = field(default_factory=list) - TP_list: List[str] = field(default_factory=list) - - def __str__(self) -> str: - return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--beam", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--keywords-file", - type=str, - help="File contains keywords.", - ) - - parser.add_argument( - "--test-set", - type=str, - default="small", - help="small or large", - ) - - parser.add_argument( - "--keywords-score", - type=float, - default=1.5, - help=""" - The default boosting score (token level) for keywords. it will boost the - paths that match keywords to make them survive beam search. - """, - ) - - parser.add_argument( - "--keywords-threshold", - type=float, - default=0.35, - help="The default threshold (probability) to trigger the keyword.", - ) - - parser.add_argument( - "--num-tailing-blanks", - type=int, - default=1, - help="The number of tailing blanks should have after hitting one keyword.", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - keywords_graph: Optional[ContextGraph] = None, -) -> List[List[Tuple[str, Tuple[int, int]]]]: - """Decode one batch and return the result in a list. - - The length of the list equals to batch size, the i-th element contains the - triggered keywords for the i-th utterance in the given batch. The triggered - keywords are also a list, each of it contains a tuple of hitting keyword and - the corresponding start timestamps and end timestamps of the hitting keyword. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. 
- keywords_graph: - The graph containing keywords. - Returns: - Return the decoding result. See above description for the format of - the returned list. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - ans_dict = keywords_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - keywords_graph=keywords_graph, - beam=params.beam, - num_tailing_blanks=params.num_tailing_blanks, - blank_penalty=params.blank_penalty, - ) - - hyps = [] - for ans in ans_dict: - hyp = [] - for hit in ans: - hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1]))) - hyps.append(hyp) - - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - keywords_graph: ContextGraph, - keywords: Set[str], - test_only_keywords: bool, -) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - keywords_graph: - The graph containing keywords. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 50 - - results = [] - metric = {"all": KwMetric()} - for k in keywords: - metric[k] = KwMetric() - - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - sp=sp, - keywords_graph=keywords_graph, - batch=batch, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = ref_text.upper() - ref_words = ref_text.split() - hyp_words = [x[0] for x in hyp_words] - # for computing WER - this_batch.append((cut_id, ref_words, " ".join(hyp_words).split())) - hyp_set = set(hyp_words) # each item is a keyword phrase - if len(hyp_words) > 1: - logging.warning( - f"Cut {cut_id} triggers more than one keywords : {hyp_words}," - f"please check the transcript to see if it really has more " - f"than one keywords, if so consider splitting this audio and" - f"keep only one keyword for each audio." - ) - hyp_str = " | ".join( - hyp_words - ) # The triggered keywords for this utterance. 
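The keyword file read in main() further below accepts one phrase per line, optionally annotated with a per-keyword boosting score (a token starting with ":") and a per-keyword trigger threshold (a token starting with "#"); a value of 0 presumably falls back to the --keywords-score / --keywords-threshold defaults. A sketch of parsing one such line (the example line is made up; real lists live in files like data/commands_small.txt):

    # Parse one keyword line, following the parser in main() further below:
    # ":x" sets the boosting score, "#y" sets the trigger threshold, and the
    # remaining tokens form the (uppercased) phrase.
    def parse_keyword_line(line: str):
        score, threshold, words = 0.0, 0.0, []
        for token in line.strip().upper().split():
            if token.startswith(":"):
                score = float(token[1:])
            elif token.startswith("#"):
                threshold = float(token[1:])
            else:
                words.append(token)
        return " ".join(words), score, threshold

    print(parse_keyword_line("turn on the light :2.0 #0.4"))
    # ('TURN ON THE LIGHT', 2.0, 0.4)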
- TP = False - FP = False - for x in hyp_set: - assert x in keywords, x # can only trigger keywords - if (test_only_keywords and x == ref_text) or ( - not test_only_keywords and x in ref_text - ): - TP = True - metric[x].TP += 1 - metric[x].TP_list.append(f"({ref_text} -> {x})") - if (test_only_keywords and x != ref_text) or ( - not test_only_keywords and x not in ref_text - ): - FP = True - metric[x].FP += 1 - metric[x].FP_list.append(f"({ref_text} -> {x})") - if TP: - metric["all"].TP += 1 - if FP: - metric["all"].FP += 1 - TN = True # all keywords are true negative then the summery is true negative. - FN = False - for x in keywords: - if x not in ref_text and x not in hyp_set: - metric[x].TN += 1 - continue - - TN = False - if (test_only_keywords and x == ref_text) or ( - not test_only_keywords and x in ref_text - ): - fn = True - for y in hyp_set: - if (test_only_keywords and y == ref_text) or ( - not test_only_keywords and y in ref_text - ): - fn = False - break - if fn: - FN = True - metric[x].FN += 1 - metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") - if TN: - metric["all"].TN += 1 - if FN: - metric["all"].FN += 1 - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results, metric - - -def save_results( - params: AttributeDict, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], - metric: KwMetric, -): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
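The bookkeeping above can be summarized per keyword as: a triggered keyword that matches the reference is a true positive, a triggered keyword that does not match is a false positive, a keyword present in the reference but never triggered is a false negative, and a keyword that is neither present nor triggered is a true negative. The condensed sketch below covers the exact-match case used for the "fsc" style test sets (test_only_keywords=True); the real code also handles substring matches for open transcripts.

    # Condensed per-utterance keyword accounting (exact-match variant).
    from collections import defaultdict

    def score_utterance(ref_text: str, triggered: set, keywords: set, counts):
        for x in triggered:
            counts[x]["TP" if x == ref_text else "FP"] += 1
        for x in keywords:
            if x != ref_text and x not in triggered:
                counts[x]["TN"] += 1
            elif x == ref_text and ref_text not in triggered:
                counts[x]["FN"] += 1

    counts = defaultdict(lambda: defaultdict(int))
    keywords = {"TURN ON THE LIGHT", "PLAY MUSIC"}
    score_utterance("TURN ON THE LIGHT", {"PLAY MUSIC"}, keywords, counts)
    print(dict(counts["PLAY MUSIC"]), dict(counts["TURN ON THE LIGHT"]))
    # {'FP': 1} {'FN': 1}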
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" - - with open(metric_filename, "w") as of: - width = 10 - for key, item in sorted( - metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True - ): - acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) - precision = ( - 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) - ) - recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) - fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) - s = f"{key}:\n" - s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" - s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" - s += f"\tAccuracy: {acc:.3f}\n" - s += f"\tPrecision: {precision:.3f}\n" - s += f"\tRecall(PPR): {recall:.3f}\n" - s += f"\tFPR: {fpr:.3f}\n" - s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" - if key != "all": - s += f"\tTP list: {' # '.join(item.TP_list)}\n" - s += f"\tFP list: {' # '.join(item.FP_list)}\n" - s += f"\tFN list: {' # '.join(item.FN_list)}\n" - of.write(s + "\n") - if key == "all": - logging.info(s) - of.write(f"\n\n{params.keywords_config}") - - logging.info("Wrote metric stats to {}".format(metric_filename)) - - -@torch.no_grad() -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "kws" - - params.suffix = params.test_set - if params.iter > 0: - params.suffix += f"-iter-{params.iter}-avg-{params.avg}" - else: - params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
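For reference, these are the quantities the metric file reports for each keyword and for the "all" summary; the helper below recomputes them from raw counts with the same formulas, plus guards for empty denominators. The example numbers are made up.

    # Summary statistics written above, computed from raw TP/FP/FN/TN counts.
    def kws_summary(TP: int, FP: int, FN: int, TN: int) -> dict:
        total = TP + TN + FP + FN
        precision = TP / (TP + FP) if TP + FP else 0.0
        recall = TP / (TP + FN) if TP + FN else 0.0  # reported as Recall(PPR) above
        return {
            "accuracy": (TP + TN) / total if total else 0.0,
            "precision": precision,
            "recall": recall,
            "fpr": FP / (FP + TN) if FP + TN else 0.0,
            "f1": 2 * precision * recall / (precision + recall) if precision * recall else 0.0,
        }

    print(kws_summary(TP=90, FP=5, FN=10, TN=895))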
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - params.suffix += f"-score-{params.keywords_score}" - params.suffix += f"-threshold-{params.keywords_threshold}" - params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" - if params.blank_penalty != 0: - params.suffix += f"-blank-penalty-{params.blank_penalty}" - params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - phrases = [] - token_ids = [] - keywords_scores = [] - keywords_thresholds = [] - keywords_config = [] - with open(params.keywords_file, "r") as f: - for line in f.readlines(): - keywords_config.append(line) - score = 0 - threshold = 0 - keyword = [] - words = line.strip().upper().split() - for word in words: - word = word.strip() - if word[0] == ":": - score = float(word[1:]) - continue - if word[0] == "#": - threshold = float(word[1:]) - continue - keyword.append(word) - keyword = " ".join(keyword) - phrases.append(keyword) - token_ids.append(sp.encode(keyword)) - keywords_scores.append(score) - keywords_thresholds.append(threshold) - - params.keywords_config = "".join(keywords_config) - - keywords_graph = ContextGraph( - context_score=params.keywords_score, ac_threshold=params.keywords_threshold - ) - keywords_graph.build( - token_ids=token_ids, - phrases=phrases, - scores=keywords_scores, - ac_thresholds=keywords_thresholds, - ) - keywords = set(phrases) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from 
{filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - gigaspeech = GigaSpeechAsrDataModule(args) - - test_cuts = gigaspeech.test_cuts() - test_dl = gigaspeech.test_dataloaders(test_cuts) - - if params.test_set == "small": - test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts() - test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts) - test_sets = ["small-fsc", "test"] - test_dls = [test_fsc_small_dl, test_dl] - else: - assert params.test_set == "large", params.test_set - test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts() - test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts) - test_sets = ["large-fsc", "test"] - test_dls = [test_fsc_large_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results, metric = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - keywords_graph=keywords_graph, - keywords=keywords, - test_only_keywords="fsc" in test_set, - ) - - save_results( - params=params, - test_set_name=test_set, - results=results, - metric=metric, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/KWS/zipformer/decoder.py b/egs/gigaspeech/KWS/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/gigaspeech/KWS/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/encoder_interface.py b/egs/gigaspeech/KWS/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/gigaspeech/KWS/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/export.py b/egs/gigaspeech/KWS/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/gigaspeech/KWS/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py deleted file mode 100755 index a7ba56127..000000000 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ /dev/null @@ -1,642 +0,0 @@ 
-#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -# For non-streaming model training: -./zipformer/finetune.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/fintune.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut, CutSet -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from train import ( - add_model_arguments, - add_training_arguments, - compute_loss, - compute_validation_loss, - display_and_save_batch, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, -) - -from icefall import diagnostics -from icefall.checkpoint import remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--use-mux", - type=str2bool, - default=False, - help=""" - Whether to adapt. If true, we will mix 5% of the new data - with 95% of the original data to fine-tune. 
- """, - ) - - parser.add_argument( - "--init-modules", - type=str, - default=None, - help=""" - Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma seperated. If None, - all modules will be initialised. For example, if you only want to - initialise all parameters staring with "encoder", use "encoder"; - if you want to initialise parameters starting with encoder or decoder, - use "encoder,joiner". - """, - ) - - parser.add_argument( - "--finetune-ckpt", - type=str, - default=None, - help="Fine-tuning from which checkpoint (a path to a .pt file)", - ) - - parser.add_argument( - "--continue-finetune", - type=str2bool, - default=False, - help="Continue finetuning or finetune from pre-trained model", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - add_training_arguments(parser) - add_model_arguments(parser) - add_finetune_arguments(parser) - - return parser - - -def load_model_params( - ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True -): - """Load model params from checkpoint - - Args: - ckpt (str): Path to the checkpoint - model (nn.Module): model to be loaded - - """ - logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") - - # if module list is empty, load the whole model from ckpt - if not init_modules: - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - else: - src_state_dict = checkpoint["model"] - dst_state_dict = model.state_dict() - for module in init_modules: - logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [ - k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") - ] - dst_keys = [ - k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") - ] - assert set(src_keys) == set(dst_keys) # two sets should match exactly - for key in src_keys: - dst_state_dict[key] = src_state_dict.pop(key) - - model.load_state_dict(dst_state_dict, strict=strict) - - return None - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. 
- valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params) + 100000) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - - # if params.continue_finetune: - # set_batch_count(model, params.batch_idx_train) - # else: - # set_batch_count(model, params.batch_idx_train + 100000) - - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - - if params.continue_finetune: - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - else: - modules = params.init_modules.split(",") if params.init_modules else None - checkpoints = load_model_params( - ckpt=params.finetune_ckpt, model=model, init_modules=modules - ) - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_utt(c: Cut): - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - return T > 0 - - gigaspeech = GigaSpeechAsrDataModule(args) - - if params.use_mux: - train_cuts = CutSet.mux( - gigaspeech.train_cuts(), - gigaspeech.fsc_train_cuts(), - weights=[0.9, 0.1], - ) - else: - train_cuts = gigaspeech.fsc_train_cuts() - - train_cuts = train_cuts.filter(remove_short_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = gigaspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = gigaspeech.fsc_valid_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - 
valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and params.scan_for_oom_batches: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py deleted file mode 120000 index 4ee54fff5..000000000 --- a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/zipformer/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/joiner.py b/egs/gigaspeech/KWS/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/gigaspeech/KWS/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/model.py b/egs/gigaspeech/KWS/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/gigaspeech/KWS/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/optim.py b/egs/gigaspeech/KWS/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/gigaspeech/KWS/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/scaling.py b/egs/gigaspeech/KWS/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/gigaspeech/KWS/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/subsampling.py b/egs/gigaspeech/KWS/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/gigaspeech/KWS/zipformer/subsampling.py +++ /dev/null @@ -1 
+0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py deleted file mode 100755 index 39d8fc6cd..000000000 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ /dev/null @@ -1,1367 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import GigaSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the 
reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="1,1,1,1,1,1", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="192,192,192,192,192,192", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="128,128,128,128,128,128", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="128,128,128,128,128,128", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=320, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=320, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=True, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - add_training_arguments(parser) - add_model_arguments(parser) - - return parser - - -def add_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=1, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--scan-for-oom-batches", - type=str2bool, - default=False, - help=""" - Whether to scan for oom batches before training, this is helpful for - finding the suitable max_duration, you only need to run it once. - Caution: a little time consuming. - """, - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 500, - "reset_interval": 2000, - "valid_interval": 20000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = 
get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_utt(c: Cut): - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - return T > 0 - - gigaspeech = GigaSpeechAsrDataModule(args) - - train_cuts = gigaspeech.train_cuts() - train_cuts = train_cuts.filter(remove_short_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = gigaspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = gigaspeech.dev_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and params.scan_for_oom_batches: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading 
grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - GigaSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/KWS/zipformer/zipformer.py b/egs/gigaspeech/KWS/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/gigaspeech/KWS/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/README.md b/egs/ksponspeech/ASR/README.md old mode 100644 new mode 100755 diff --git a/egs/ksponspeech/ASR/RESULTS.md b/egs/ksponspeech/ASR/RESULTS.md old mode 100644 new mode 100755 diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/README.md old mode 100644 new mode 100755 diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/libricss/SURT/README.md b/egs/libricss/SURT/README.md deleted file mode 100644 index 10a1aaad1..000000000 --- a/egs/libricss/SURT/README.md +++ /dev/null @@ -1,249 +0,0 @@ -# Introduction - -This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming -Unmixing and Recognition Transducer (SURT) model for the task. In this README, -we will describe the task, the model, and the training process. We will also -provide links to pre-trained models and training logs. - -## Task - -LibriCSS is a multi-talker meeting corpus formed from mixing together LibriSpeech utterances -and replaying in a real meeting room. It consists of 10 1-hour sessions of audio, each -recorded on a 7-channel microphone. The sessions are recorded at a sampling rate of 16 kHz. -For more information, refer to the paper: -Z. Chen et al., "Continuous speech separation: dataset and analysis," -ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), -Barcelona, Spain, 2020 - -In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS. - -* By "continuous", we mean that the model should be able to transcribe unsegmented audio -without the need of an external VAD. -* By "streaming", we mean that the model has limited right context. We use a right-context -of at most 32 frames (320 ms). -* By "multi-talker", we mean that the model should be able to transcribe overlapping speech -from multiple speakers. - -For now, we do not care about speaker attribution, i.e., the transcription is speaker -agnostic. The evaluation depends on the particular model type. In this case, we use -the optimal reference combination WER (ORC-WER) metric as implemented in the -[meeteval](https://github.com/fgnt/meeteval) toolkit. - -## Model - -We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task. -The model is based on the papers: - -- Lu, Liang et al. 
“Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807. -- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321. - -The model is a combination of a speech separation model and a speech recognition model, -but trained end-to-end with a single loss function. The overall architecture is shown -in the figure below, and a toy sketch of the forward computation is given after the figure. Note that this architecture is slightly different from the one -in the above papers. A detailed description of the model can be found in the following -paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559). - -

- *Figure: Streaming Unmixing and Recognition Transducer (SURT) architecture.* - -
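To make the "separation + recognition" structure concrete, here is a minimal, hypothetical sketch of a two-channel SURT-style forward pass. It mirrors the way the recipe's decoding code applies the mask encoder and stacks the masked feature copies along the batch axis, but the modules below are simple placeholders (plain linear layers), not the actual DPRNN masking network or Zipformer encoder.

```python
import torch
import torch.nn as nn


class TinySurtSketch(nn.Module):
    """Illustrative only: a 2-channel SURT-style forward pass.

    The real recipe uses a DPRNN mask encoder and a Zipformer
    recognition encoder; plain linear layers stand in for them here.
    """

    def __init__(self, feat_dim: int = 80, num_channels: int = 2, enc_dim: int = 256):
        super().__init__()
        self.num_channels = num_channels
        # Masking network: predicts one mask per output channel.
        self.mask_encoder = nn.Linear(feat_dim, feat_dim * num_channels)
        # Recognition encoder: shared by all channels.
        self.encoder = nn.Linear(feat_dim, enc_dim)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        # feature: (B, T, F)
        B, T, F = feature.shape
        processed = self.mask_encoder(feature)  # (B, T, F * num_channels)
        masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1)
        x_masked = [feature * m for m in masks]  # one masked copy per channel
        # Stack the channels along the batch axis so a single recognition
        # encoder (and transducer loss) can process them independently.
        h = torch.cat(x_masked, dim=0)  # (B * num_channels, T, F)
        return self.encoder(h)  # (B * num_channels, T, enc_dim)


if __name__ == "__main__":
    model = TinySurtSketch()
    out = model(torch.randn(4, 100, 80))
    print(out.shape)  # torch.Size([8, 100, 256])
```

Stacking the per-channel streams along the batch dimension is what lets the rest of the pipeline remain a standard single-stream transducer.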

- -In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network -and a Zipformer-based recognition network. But other combinations are possible as well. - -## Training objective - -We train the model using the pruned transducer loss, similar to other ASR recipes in -icefall. However, an important consideration is how to assign references to the output -channels (2 in this case). For this, we use the heuristic error assignment training (HEAT) -strategy, which assigns references to the first available channel based on their start -times. An illustrative example is shown in the figure below, and a small code sketch of the heuristic follows it: - -

- *Figure: Illustration of HEAT-based reference assignment.* - -
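As a complement to the figure, here is a small, self-contained sketch of the HEAT assignment rule described above. It assumes each reference utterance is given as a hypothetical `(start, end, text)` triple and simply sends every utterance to the first channel that is free at its start time; it illustrates the heuristic, not the recipe's actual implementation.

```python
from typing import List, Tuple


def heat_assign(
    utterances: List[Tuple[float, float, str]],  # (start, end, text)
    num_channels: int = 2,
) -> List[List[str]]:
    """Assign each utterance to the first output channel that is free
    at the utterance's start time (HEAT-style reference assignment)."""
    # References are processed in order of start time.
    utterances = sorted(utterances, key=lambda u: u[0])
    channel_free_at = [0.0] * num_channels  # time when each channel becomes free
    channels: List[List[str]] = [[] for _ in range(num_channels)]

    for start, end, text in utterances:
        for c in range(num_channels):
            if channel_free_at[c] <= start:  # channel c is available
                channels[c].append(text)
                channel_free_at[c] = end
                break
        else:
            # All channels are busy: fall back to the one that frees up first.
            c = min(range(num_channels), key=lambda i: channel_free_at[i])
            channels[c].append(text)
            channel_free_at[c] = end
    return channels


# Two overlapping utterances end up on different channels:
print(heat_assign([(0.0, 3.0, "hello there"), (1.5, 4.0, "good morning"), (5.0, 6.0, "bye")]))
# [['hello there', 'bye'], ['good morning']]
```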

- -## Description of the recipe - -### Pre-requisites - -The recipes in this directory need the following packages to be installed: - -- [meeteval](https://github.com/fgnt/meeteval) -- [einops](https://github.com/arogozhnikov/einops) - -Additionally, we initialize the "recognition" transducer with a pre-trained model, -trained on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`: - -```bash -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0,1,2,3" -python pruned_transducer_stateless7_streaming/train.py \ - --use-fp16 True \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --world-size 4 \ - --max-duration 800 \ - --num-epochs 10 \ - --keep-last-k 1 \ - --manifest-dir data/manifests \ - --enable-musan true \ - --master-port 54321 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --num-encoder-layers 2,2,2,2,2 \ - --feedforward-dims 768,768,768,768,768 \ - --nhead 8,8,8,8,8 \ - --encoder-dims 256,256,256,256,256 \ - --attention-dims 192,192,192,192,192 \ - --encoder-unmasked-dims 192,192,192,192,192 \ - --zipformer-downsampling-factors 1,2,4,8,2 \ - --cnn-module-kernels 31,31,31,31,31 \ - --decoder-dim 512 \ - --joiner-dim 512 -``` - -The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`. - -Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`: - -```bash -cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt -``` - -**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip -the above step if you want. - -### Training - -To train the model, run the following from within `egs/libricss/SURT`: - -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -python dprnn_zipformer/train.py \ - --use-fp16 True \ - --exp-dir dprnn_zipformer/exp/surt_base \ - --world-size 4 \ - --max-duration 500 \ - --max-duration-valid 250 \ - --max-cuts 200 \ - --num-buckets 50 \ - --num-epochs 30 \ - --enable-spec-aug True \ - --enable-musan False \ - --ctc-loss-scale 0.2 \ - --heat-loss-scale 0.2 \ - --base-lr 0.004 \ - --model-init-ckpt exp/zipformer_base.pt \ - --chunk-width-randomization True \ - --num-mask-encoder-layers 4 \ - --num-encoder-layers 2,2,2,2,2 -``` - -The above is for SURT-base (~26M). For SURT-large (~38M), use: - -```bash - --num-mask-encoder-layers 6 \ - --num-encoder-layers 2,4,3,2,4 \ - --model-init-ckpt exp/zipformer_large.pt \ -``` - -**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM. - -### Adaptation - -The training step above only trains on simulated mixtures. For best results, we also -adapt the final model on the LibriCSS dev set. 
For this, run the following from within -`egs/libricss/SURT`: - -```bash -export CUDA_VISIBLE_DEVICES="0" - -python dprnn_zipformer/train_adapt.py \ - --use-fp16 True \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --world-size 1 \ - --max-duration 500 \ - --max-duration-valid 250 \ - --max-cuts 200 \ - --num-buckets 50 \ - --num-epochs 8 \ - --lr-epochs 2 \ - --enable-spec-aug True \ - --enable-musan False \ - --ctc-loss-scale 0.2 \ - --base-lr 0.0004 \ - --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \ - --chunk-width-randomization True \ - --num-mask-encoder-layers 4 \ - --num-encoder-layers 2,2,2,2,2 -``` - -For SURT-large, use the following config: - -```bash - --num-mask-encoder-layers 6 \ - --num-encoder-layers 2,4,3,2,4 \ - --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \ - --num-epochs 15 \ - --lr-epochs 4 \ -``` - - -### Decoding - -To decode the model, run the following from within `egs/libricss/SURT`: - -#### Greedy search - -```bash -export CUDA_VISIBLE_DEVICES="0" - -python dprnn_zipformer/decode.py \ - --epoch 8 --avg 1 --use-averaged-model False \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --max-duration 250 \ - --decoding-method greedy_search -``` - -#### Beam search - -```bash -python dprnn_zipformer/decode.py \ - --epoch 8 --avg 1 --use-averaged-model False \ - --exp-dir dprnn_zipformer/exp/surt_base_adapt \ - --max-duration 250 \ - --decoding-method modified_beam_search \ - --beam-size 4 -``` - -## Results (using beam search) - -#### IHM-Mix - -| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. | -|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:| -| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 | -| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 | - -#### SDM - -| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. | -|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:| -| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 | -| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 | - -## Pre-trained models and logs - -* Pre-trained models: - -* Training logs: - - surt_base: - - surt_base_adapt: - - surt_large: - - surt_large_adapt: diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py deleted file mode 100644 index 500df9ea4..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# Copyright 2023 Johns Hopkins Univrtsity (Author: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutMix, - DynamicBucketingSampler, - K2SurtDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriCssAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--max-duration-valid", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--max-cuts", - type=int, - default=100, - help="Maximum number of cuts in a single batch. You can " - "reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - return_sources: bool = True, - strict: bool = True, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SurtDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - return_sources=return_sources, - strict=strict, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - quadratic_duration=30.0, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - - logging.info("About to create dev dataset") - validate = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - return_sources=False, - strict=False, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration_valid, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - return_sources=False, - strict=False, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration_valid, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def lsmix_cuts( - self, - rvb_affix: str = "clean", - type_affix: str = "full", - sources: bool = True, - ) -> CutSet: - logging.info("About to get train cuts") - source_affix = "_sources" if sources else "" - cs = load_manifest_lazy( - self.args.manifest_dir - / f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz" - ) - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) - return cs - - @lru_cache() - def libricss_cuts(self, split="dev", type="sdm") -> CutSet: - logging.info(f"About to get LibriCSS {split} {type} cuts") - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz" - ) - return cs diff --git a/egs/libricss/SURT/dprnn_zipformer/beam_search.py b/egs/libricss/SURT/dprnn_zipformer/beam_search.py deleted file mode 100644 index c8e4643d0..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/beam_search.py +++ /dev/null @@ -1,730 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import torch -from model import SURT - -from icefall import NgramLmStateCost -from icefall.utils import DecodingResults - - -def greedy_search( - model: SURT, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `SURT`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 4 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: SURT, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The SURT model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
- """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search( - model: SURT, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The SURT model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
- """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. 
- current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def beam_search( - model: SURT, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_SURT.py#L247 is used as a reference. - - Args: - model: - An instance of `SURT`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. 
- log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. 
- """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans diff --git a/egs/libricss/SURT/dprnn_zipformer/decode.py b/egs/libricss/SURT/dprnn_zipformer/decode.py deleted file mode 100755 index 6abbffe00..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/decode.py +++ /dev/null @@ -1,654 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./dprnn_zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --use-averaged-model true \ - --exp-dir ./dprnn_zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./dprnn_zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --use-averaged-model true \ - --exp-dir ./dprnn_zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriCssAsrDataModule -from beam_search import ( - beam_search, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.utils import EPSILON -from train import add_model_arguments, get_params, get_surt_model - -from icefall import LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_surt_error_stats, -) - -OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="dprnn_zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--save-masks", - type=str2bool, - default=False, - help="""If true, save masks generated by unmixing module.""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - feature_lens = batch["input_lens"].to(device) - - # Apply the mask encoder - B, T, F = feature.shape - processed = model.mask_encoder(feature) # B,T,F*num_channels - masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) - x_masked = [feature * m for m in masks] - - masks_dict = {} - if params.save_masks: - # To save the masks, we split them by batch and trim each mask to the length of - # the corresponding feature. We save them in a dict, where the key is the - # cut ID and the value is the mask. 
- for i in range(B): - mask = torch.cat( - [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)], - dim=-1, - ) - mask = mask.cpu().numpy() - masks_dict[batch["cuts"][i].id] = mask - - # Recognition - # Concatenate the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = feature_lens.repeat(params.num_channels) - encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) - - if model.joint_encoder_layer is not None: - encoder_out = model.joint_encoder_layer(encoder_out) - - def _group_channels(hyps: List[str]) -> List[List[str]]: - """ - Currently we have a batch of size M*B, where M is the number of - channels and B is the batch size. We need to group the hypotheses - into B groups, each of which contains M hypotheses. - - Example: - hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] - _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] - """ - assert len(hyps) == B * params.num_channels - out_hyps = [] - for i in range(B): - out_hyps.append(hyps[i::B]) - return out_hyps - - hyps = [] - if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp)) - - if params.decoding_method == "greedy_search": - return {"greedy_search": _group_channels(hyps)}, masks_dict - else: - return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - masks = {} - for batch_idx, batch in enumerate(dl): - cut_ids = [cut.id for cut in batch["cuts"]] - cuts_batch = batch["cuts"] - - hyps_dict, masks_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - ) - masks.update(masks_dict) - - for name, hyps in hyps_dict.items(): - this_batch = [] - for cut_id, hyp_words in zip(cut_ids, hyps): - # Reference is a list of supervision texts sorted by start time. - ref_words = [ - s.text.strip() - for s in sorted( - cuts_batch[cut_id].supervisions, key=lambda s: s.start - ) - ] - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(cut_ids) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results, masks_dict - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_surt_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - num_channels=params.num_channels, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -def save_masks( - params: AttributeDict, - test_set_name: str, - masks: List[torch.Tensor], -): - masks_path = params.res_dir / f"masks-{test_set_name}.txt" - torch.save(masks, masks_path) - logging.info(f"The masks are stored in {masks_path}") - - -@torch.no_grad() -def main(): - parser = get_parser() - LmScorer.add_arguments(parser) - LibriCssAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - ), f"Decoding method {params.decoding_method} is not supported." 
- params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - 
model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - libricss = LibriCssAsrDataModule(args) - - dev_cuts = libricss.libricss_cuts(split="dev", type="ihm-mix").to_eager() - dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS] - test_cuts = libricss.libricss_cuts(split="test", type="ihm-mix").to_eager() - test_cuts_grouped = [ - test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS - ] - - for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS): - dev_dl = libricss.test_dataloaders(dev_set) - results_dict, masks = decode_dataset( - dl=dev_dl, - params=params, - model=model, - sp=sp, - ) - - save_results( - params=params, - test_set_name=f"dev_{ol}", - results_dict=results_dict, - ) - - if params.save_masks: - save_masks( - params=params, - test_set_name=f"dev_{ol}", - masks=masks, - ) - - for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS): - test_dl = libricss.test_dataloaders(test_set) - results_dict, masks = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - ) - - save_results( - params=params, - test_set_name=f"test_{ol}", - results_dict=results_dict, - ) - - if params.save_masks: - save_masks( - params=params, - test_set_name=f"test_{ol}", - masks=masks, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_zipformer/decoder.py b/egs/libricss/SURT/dprnn_zipformer/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/dprnn.py b/egs/libricss/SURT/dprnn_zipformer/dprnn.py deleted file mode 100644 index 440dea885..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/dprnn.py +++ /dev/null @@ -1,305 +0,0 @@ -import random -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from einops import rearrange -from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM -from torch.autograd import Variable - -EPS = torch.finfo(torch.get_default_dtype()).eps - - -def _pad_segment(input, segment_size): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342 - # input is the features: (B, N, T) - batch_size, dim, seq_len = input.shape - segment_stride = segment_size // 2 - - rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size - if rest > 0: - pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type()) - input = torch.cat([input, pad], 2) - - pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type()) - input = torch.cat([pad_aux, input, pad_aux], 2) - - return input, rest - - -def split_feature(input, segment_size): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358 - # split the feature into chunks of segment size - # input is the features: (B, N, T) - - input, rest = _pad_segment(input, segment_size) - batch_size, dim, seq_len = input.shape - segment_stride = segment_size // 2 - - segments1 = ( - input[:, :, :-segment_stride] - .contiguous() - .view(batch_size, dim, -1, segment_size) - ) - segments2 = ( - input[:, :, segment_stride:] - .contiguous() - .view(batch_size, dim, -1, segment_size) - ) 
- segments = ( - torch.cat([segments1, segments2], 3) - .view(batch_size, dim, -1, segment_size) - .transpose(2, 3) - ) - - return segments.contiguous(), rest - - -def merge_feature(input, rest): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385 - # merge the splitted features into full utterance - # input is the features: (B, N, L, K) - - batch_size, dim, segment_size, _ = input.shape - segment_stride = segment_size // 2 - input = ( - input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2) - ) # B, N, K, L - - input1 = ( - input[:, :, :, :segment_size] - .contiguous() - .view(batch_size, dim, -1)[:, :, segment_stride:] - ) - input2 = ( - input[:, :, :, segment_size:] - .contiguous() - .view(batch_size, dim, -1)[:, :, :-segment_stride] - ) - - output = input1 + input2 - if rest > 0: - output = output[:, :, :-rest] - - return output.contiguous() # B, N, T - - -class RNNEncoderLayer(nn.Module): - """ - RNNEncoderLayer is made up of lstm and feedforward networks. - Args: - input_size: - The number of expected features in the input (required). - hidden_size: - The hidden dimension of rnn layer. - dropout: - The dropout value (default=0.1). - layer_dropout: - The dropout value for model-level warmup (default=0.075). - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - dropout: float = 0.1, - bidirectional: bool = False, - ) -> None: - super(RNNEncoderLayer, self).__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - assert hidden_size >= input_size, (hidden_size, input_size) - self.lstm = ScaledLSTM( - input_size=input_size, - hidden_size=hidden_size // 2 if bidirectional else hidden_size, - proj_size=0, - num_layers=1, - dropout=0.0, - batch_first=True, - bidirectional=bidirectional, - ) - self.norm_final = BasicNorm(input_size) - - # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa - self.balancer = ActivationBalancer( - num_channels=input_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (1, N, input_size); - states[1] is the cell states of all layers, - with shape of (1, N, hidden_size). - """ - src_orig = src - - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - alpha = warmup if self.training else 1.0 - - # lstm module - src_lstm, new_states = self.lstm(src, states) - src = self.dropout(src_lstm) + src - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src - - -# dual-path RNN -class DPRNN(nn.Module): - """Deep dual-path RNN. - Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py - - args: - input_size: int, dimension of the input feature. The input should have shape - (batch, seq_len, input_size). - hidden_size: int, dimension of the hidden state. 
- output_size: int, dimension of the output size. - dropout: float, dropout ratio. Default is 0. - num_blocks: int, number of stacked RNN layers. Default is 1. - """ - - def __init__( - self, - feature_dim, - input_size, - hidden_size, - output_size, - dropout=0.1, - num_blocks=1, - segment_size=50, - chunk_width_randomization=False, - ): - super().__init__() - - self.input_size = input_size - self.output_size = output_size - self.hidden_size = hidden_size - - self.segment_size = segment_size - self.chunk_width_randomization = chunk_width_randomization - - self.input_embed = nn.Sequential( - ScaledLinear(feature_dim, input_size), - BasicNorm(input_size), - ActivationBalancer( - num_channels=input_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - ), - ) - - # dual-path RNN - self.row_rnn = nn.ModuleList([]) - self.col_rnn = nn.ModuleList([]) - for _ in range(num_blocks): - # intra-RNN is non-causal - self.row_rnn.append( - RNNEncoderLayer( - input_size, hidden_size, dropout=dropout, bidirectional=True - ) - ) - self.col_rnn.append( - RNNEncoderLayer( - input_size, hidden_size, dropout=dropout, bidirectional=False - ) - ) - - # output layer - self.out_embed = nn.Sequential( - ScaledLinear(input_size, output_size), - BasicNorm(output_size), - ActivationBalancer( - num_channels=output_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - ), - ) - - def forward(self, input): - # input shape: B, T, F - input = self.input_embed(input) - B, T, D = input.shape - - if self.chunk_width_randomization and self.training: - segment_size = random.randint(self.segment_size // 2, self.segment_size) - else: - segment_size = self.segment_size - input, rest = split_feature(input.transpose(1, 2), segment_size) - # input shape: batch, N, dim1, dim2 - # apply RNN on dim1 first and then dim2 - # output shape: B, output_size, dim1, dim2 - # input = input.to(device) - batch_size, _, dim1, dim2 = input.shape - output = input - for i in range(len(self.row_rnn)): - row_input = ( - output.permute(0, 3, 2, 1) - .contiguous() - .view(batch_size * dim2, dim1, -1) - ) # B*dim2, dim1, N - output = self.row_rnn[i](row_input) # B*dim2, dim1, H - output = ( - output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous() - ) # B, N, dim1, dim2 - - col_input = ( - output.permute(0, 2, 3, 1) - .contiguous() - .view(batch_size * dim1, dim2, -1) - ) # B*dim1, dim2, N - output = self.col_rnn[i](col_input) # B*dim1, dim2, H - output = ( - output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous() - ) # B, N, dim1, dim2 - - output = merge_feature(output, rest) - output = output.transpose(1, 2) - output = self.out_embed(output) - - # Apply ReLU to the output - output = torch.relu(output) - - return output - - -if __name__ == "__main__": - - model = DPRNN( - 80, - 256, - 256, - 160, - dropout=0.1, - num_blocks=4, - segment_size=32, - chunk_width_randomization=True, - ) - input = torch.randn(2, 1002, 80) - print(sum(p.numel() for p in model.parameters())) - print(model(input).shape) diff --git a/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py b/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py deleted file mode 120000 index 0c2673d46..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/export.py b/egs/libricss/SURT/dprnn_zipformer/export.py deleted file mode 100755 index 
f51f2a7ab..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/export.py +++ /dev/null @@ -1,306 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./dprnn_zipformer/export.py \ - --exp-dir ./dprnn_zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./dprnn_zipformer/export.py \ - --exp-dir ./dprnn_zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `dprnn_zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./dprnn_zipformer/decode.py \ - --exp-dir ./dprnn_zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_surt_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="dprnn_zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libricss/SURT/dprnn_zipformer/joiner.py b/egs/libricss/SURT/dprnn_zipformer/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/model.py b/egs/libricss/SURT/dprnn_zipformer/model.py deleted file mode 100644 index 688e1e78d..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/model.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# Copyright 2023 Johns Hopkins University (author: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos - - -class SURT(nn.Module): - """It implements Streaming Unmixing and Recognition Transducer (SURT). 
- https://arxiv.org/abs/2011.13148 - """ - - def __init__( - self, - mask_encoder: nn.Module, - encoder: EncoderInterface, - joint_encoder_layer: Optional[nn.Module], - decoder: nn.Module, - joiner: nn.Module, - num_channels: int, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - mask_encoder: - It is the masking network. It generates a mask for each channel of the - encoder. These masks are applied to the input features, and then passed - to the transcription network. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - num_channels: - It is the number of channels that the input features will be split into. - In general, it should be equal to the maximum number of simultaneously - active speakers. For most real scenarios, using 2 channels is sufficient. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.mask_encoder = mask_encoder - self.encoder = encoder - self.joint_encoder_layer = joint_encoder_layer - self.decoder = decoder - self.joiner = joiner - self.num_channels = num_channels - - self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, - ) - self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) - - self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), - nn.Linear(encoder_dim, vocab_size), - nn.LogSoftmax(dim=-1), - ) - - def forward_helper( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - reduction: str = "sum", - beam_size: int = 10, - use_double_scores: bool = False, - subsampling_factor: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute transducer loss for one branch of the SURT model. - """ - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - if self.joint_encoder_layer is not None: - encoder_out = self.joint_encoder_layer(encoder_out) - - # compute ctc log-probs - ctc_output = self.ctc_output(encoder_out) - - # For the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. 
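# A tiny illustration of add_sos + pad, with hypothetical values that are not
# part of the original file (blank_id == 0):
#
#   y            = k2.RaggedTensor([[3, 5], [7]])
#   sos_y        = add_sos(y, sos_id=0)   # [[0, 3, 5], [0, 7]]
#   sos_y_padded                          # tensor([[0, 3, 5], [0, 7, 0]])
#
# i.e. every row gains a leading blank/SOS symbol and is then padded with
# blank_id to the longest row, which is what the prediction network consumes.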
- sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction=reduction, - ) - - # Compute ctc loss - supervision_segments = torch.stack( - ( - torch.arange(len(x_lens), device="cpu"), - torch.zeros_like(x_lens, device="cpu"), - torch.clone(x_lens).detach().cpu(), - ), - dim=1, - ).to(torch.int32) - # We need to sort supervision_segments in decreasing order of num_frames - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - - # Works with a BPE model - decoding_graph = k2.ctc_graph(y, modified=False, device=x.device) - dense_fsa_vec = k2.DenseFsaVec( - ctc_output, - supervision_segments, - allow_truncate=subsampling_factor - 1, - ) - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=beam_size, - reduction="none", - use_double_scores=use_double_scores, - ) - - return (simple_loss, pruned_loss, ctc_loss) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - reduction: str = "sum", - beam_size: int = 10, - use_double_scores: bool = False, - subsampling_factor: int = 1, - return_masks: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor of shape (N*num_channels, S). It contains the labels - of the N utterances. The labels are in the range [0, vocab_size). All - the channels are concatenated together one after another. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. 
- am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - beam_size: - The beam size used in CTC decoding. - use_double_scores: - If True, use double precision for CTC decoding. - subsampling_factor: - The subsampling factor of the model. It is used to compute the - supervision segments for CTC loss. - return_masks: - If True, return the masks as well as masked features. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size()) - - # Apply the mask encoder - B, T, F = x.shape - processed = self.mask_encoder(x) # B,T,F*num_channels - masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1) - x_masked = [x * m for m in masks] - - # Recognition - # Stack the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0) - - simple_loss, pruned_loss, ctc_loss = self.forward_helper( - h, - h_lens, - y, - prune_range, - am_scale, - lm_scale, - reduction=reduction, - beam_size=beam_size, - use_double_scores=use_double_scores, - subsampling_factor=subsampling_factor, - ) - - # Chunks the outputs into 2 parts along batch axis and then stack them along a new axis. - simple_loss = torch.stack( - torch.chunk(simple_loss, self.num_channels, dim=0), dim=0 - ) - pruned_loss = torch.stack( - torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0 - ) - ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0) - - if return_masks: - return (simple_loss, pruned_loss, ctc_loss, x_masked, masks) - else: - return (simple_loss, pruned_loss, ctc_loss, x_masked) diff --git a/egs/libricss/SURT/dprnn_zipformer/optim.py b/egs/libricss/SURT/dprnn_zipformer/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/pretrained.py b/egs/libricss/SURT/dprnn_zipformer/pretrained.py deleted file mode 100755 index 5f9468957..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/pretrained.py +++ /dev/null @@ -1,303 +0,0 @@ -#!/usr/bin/env python3 -""" -Usage: - -1. Download pre-trained models from -https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer - -2. 
- -./dprnn_zipformer/pretrained.py \ - --checkpoint /path/to/pretrained.pt \ - --tokens /path/to/data/lang_bpe_500/tokens.txt \ - /path/to/foo.wav -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_surt_model - -from icefall.utils import num_tokens - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - required=True, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - B, T, F = features.shape - processed = model.mask_encoder(features) # B,T,F*num_channels - masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) - x_masked = [features * m for m in masks] - - # Recognition - # Concatenate the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = feature_lengths.repeat(params.num_channels) - encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) - - if model.joint_encoder_layer is not None: - encoder_out = model.joint_encoder_layer(encoder_out) - - def _group_channels(hyps: List[str]) -> List[List[str]]: - """ - Currently we have a batch of size M*B, where M is the number of - channels and B is the batch size. We need to group the hypotheses - into B groups, each of which contains M hypotheses. 
- - Example: - hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] - _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] - """ - assert len(hyps) == B * params.num_channels - out_hyps = [] - for i in range(B): - out_hyps.append(hyps[i::B]) - return out_hyps - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - hyps.append(token_ids_to_words(hyp)) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py deleted file mode 100644 index 4040a7b89..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/scaling.py +++ /dev/null @@ -1,1576 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import logging -import random -from typing import Optional, Tuple, Union - -import torch -import torch.backends.cudnn.rnn as rnn -import torch.nn as nn -from torch import _VF, Tensor - -from icefall.utils import is_jit_tracing - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - if sign_factor is None: - ctx.save_for_backward(xgt0, scale_factor) - else: - ctx.save_for_backward(xgt0, scale_factor, sign_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - if len(ctx.saved_tensors) == 3: - xgt0, scale_factor, sign_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - sign_factor = sign_factor.unsqueeze(-1) - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - else: - xgt0, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) - - if min_abs == 0.0: - below_threshold = 0.0 - else: - # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if - # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) - - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) - - return below_threshold - above_threshold - - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) - if min_positive == 0.0: - factor1 = 0.0 - else: - # 0 if proportion_positive >= min_positive, else can be - # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) - - if max_positive == 1.0: - factor2 = 0.0 - else: - # 0 if self.proportion_positive <= max_positive, else can be - # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) - sign_factor = factor1 - factor2 - # require min_positive != 0 or max_positive != 1: - assert not isinstance(sign_factor, float) - return sign_factor - - -class ActivationScaleBalancerFunction(torch.autograd.Function): - """ - This object is used in class ActivationBalancer when the user specified - min_positive=0, max_positive=1, so there are no constraints on the signs - of the activations and only the absolute value has a constraint. 
- """ - - @staticmethod - def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - ctx.save_for_backward(xgt0, sign_factor, scale_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - xgt0, sign_factor, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - sign_factor = sign_factor.unsqueeze(-1) - scale_factor = scale_factor.unsqueeze(-1) - - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -class RandomClampFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: - x_clamped = torch.clamp(x, min=min, max=max) - mask = torch.rand_like(x) < prob - ans = torch.where(mask, x_clamped, x) - if x.requires_grad: - ctx.save_for_backward(ans == x) - ctx.reflect = reflect - if reflect != 0.0: - ans = ans * (1.0 + reflect) - (x * reflect) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors - x_grad = ans_grad * is_same.to(ans_grad.dtype) - reflect = ctx.reflect - if reflect != 0.0: - x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) - return x_grad, None, None, None, None - - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): - return RandomClampFunction.apply(x, min, max, prob, reflect) - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class RandomGradFunction(torch.autograd.Function): - """ - Does nothing in forward pass; in backward pass, gets rid of very small grads using - randomized approach that preserves expectations (intended to reduce roundoff). 
- """ - - @staticmethod - def forward(ctx, x: Tensor, min_abs: float) -> Tensor: - ctx.min_abs = min_abs - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: - if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) - else: - return ans_grad, None - - -class RandomGrad(torch.nn.Module): - """ - Gets rid of very small gradients using an expectation-preserving method, intended to increase - accuracy of training when using amp (automatic mixed precision) - """ - - def __init__(self, min_abs: float = 5.0e-06): - super(RandomGrad, self).__init__() - self.min_abs = min_abs - - def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): - return x - else: - return RandomGradFunction.apply(x, self.min_abs) - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. 
- variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class GradientFilterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - batch_dim: int, # e.g., 1 - threshold: float, # e.g., 10.0 - *params: Tensor, # module parameters - ) -> Tuple[Tensor, ...]: - if x.requires_grad: - if batch_dim < 0: - batch_dim += x.ndim - ctx.batch_dim = batch_dim - ctx.threshold = threshold - return (x,) + params - - @staticmethod - def backward( - ctx, - x_grad: Tensor, - *param_grads: Tensor, - ) -> Tuple[Tensor, ...]: - eps = 1.0e-20 - dim = ctx.batch_dim - norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() - median_norm = norm_of_batch.median() - - cutoff = median_norm * ctx.threshold - inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) - mask = 1.0 / (inv_mask + eps) - x_grad = x_grad * mask - - avg_mask = 1.0 / (inv_mask.mean() + eps) - param_grads = [avg_mask * g for g in param_grads] - - return (x_grad, None, None) + tuple(param_grads) - - -class GradientFilter(torch.nn.Module): - """This is used to filter out elements that have extremely large gradients - in batch and the module parameters with soft masks. - - Args: - batch_dim (int): - The batch dimension. - threshold (float): - For each element in batch, its gradient will be - filtered out if the gradient norm is larger than - `grad_norm_threshold * median`, where `median` is the median - value of gradient norms of all elememts in batch. - """ - - def __init__(self, batch_dim: int = 1, threshold: float = 10.0): - super(GradientFilter, self).__init__() - self.batch_dim = batch_dim - self.threshold = threshold - - def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: - if torch.jit.is_scripting() or is_jit_tracing(): - return (x,) + params - else: - return GradientFilterFunction.apply( - x, - self.batch_dim, - self.threshold, - *params, - ) - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. 
- eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = eps.clamp(min=self.eps_min, max=self.eps_max) - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() - ) ** -0.5 - return x * scales - - -class ScaledEmbedding(nn.Module): - r"""This is a modified version of nn.Embedding that introduces a learnable scale - on the parameters. Note: due to how we initialize it, it's best used with - schedulers like Noam that have a warmup period. - - It is a simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - initial_speed (float, optional): This affects how fast the parameter will - learn near the start of training; you can set it to a value less than - one if you suspect that a module is contributing to instability near - the start of training. Note: regardless of the use of this option, - it's best to use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. 
However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - - """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - initial_speed: float = 1.0, - ) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" - elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters(initial_speed) - - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.1 / initial_speed - nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - scale = self.scale.exp() - if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) - else: - return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) - - def extra_repr(self) -> str: - # s = "{num_embeddings}, {embedding_dim}, scale={scale}" - s = "{num_embeddings}, {embedding_dim}" - if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" - if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" - if self.sparse is not False: - s += ", sparse=True" - return s.format(**self.__dict__) - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of 
nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ScaledLSTM(nn.LSTM): - # See docs for ScaledLinear. - # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` - # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - grad_norm_threshold: float = 10.0, - **kwargs, - ): - super(ScaledLSTM, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self._scales_names = [] - self._scales = [] - self.batch_dim = 0 if self.batch_first else 1 - self.num_directions = 1 + int(self.bidirectional) - for name in self._flat_weights_names: - scale_name = name + "_scale" - self._scales_names.append(scale_name) - param = nn.Parameter(initial_scale.clone().detach()) - setattr(self, scale_name, param) - self._scales.append(param) - - self.grad_filter = GradientFilter( - batch_dim=self.batch_dim, threshold=grad_norm_threshold - ) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 - v = scale / std - for idx, name in enumerate(self._flat_weights_names): - if "weight" in name: - nn.init.uniform_(self._flat_weights[idx], -a, a) - with torch.no_grad(): - self._scales[idx] += torch.tensor(v).log() - elif "bias" in name: - nn.init.constant_(self._flat_weights[idx], 0.0) - - def _flatten_parameters(self, flat_weights) -> None: - """Resets parameter data pointer so that they can use faster code paths. - - Right now, this works only if the module is on the GPU and cuDNN is enabled. - Otherwise, it's a no-op. 
- - This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - """ - # Short-circuits if _flat_weights is only partially instantiated - if len(flat_weights) != len(self._flat_weights_names): - return - - for w in flat_weights: - if not isinstance(w, Tensor): - return - # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN - # or the tensors in flat_weights are of different dtypes - - first_fw = flat_weights[0] - dtype = first_fw.dtype - for fw in flat_weights: - if ( - not isinstance(fw.data, Tensor) - or not (fw.data.dtype == dtype) - or not fw.data.is_cuda - or not torch.backends.cudnn.is_acceptable(fw.data) - ): - return - - # If any parameters alias, we fall back to the slower, copying code path. This is - # a sufficient check, because overlapping parameter buffers that don't completely - # alias would break the assumptions of the uniqueness check in - # Module.named_parameters(). - unique_data_ptrs = set(p.data_ptr() for p in flat_weights) - if len(unique_data_ptrs) != len(flat_weights): - return - - with torch.cuda.device_of(first_fw): - - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is - # an inplace operation on self._flat_weights - with torch.no_grad(): - if torch._use_cudnn_rnn_flatten_weight(): - num_weights = 4 if self.bias else 2 - if self.proj_size > 0: - num_weights += 1 - torch._cudnn_rnn_flatten_weight( - flat_weights, - num_weights, - self.input_size, - rnn.get_cudnn_mode(self.mode), - self.hidden_size, - self.proj_size, - self.num_layers, - self.batch_first, - bool(self.bidirectional), - ) - - def _get_flat_weights(self): - """Get scaled weights, and resets their data pointer.""" - flat_weights = [] - for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) - self._flatten_parameters(flat_weights) - return flat_weights - - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): - # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - # The change for calling `_VF.lstm()` is: - # self._flat_weights -> self._get_flat_weights() - if hx is None: - h_zeros = torch.zeros( - self.num_layers * self.num_directions, - input.size(self.batch_dim), - self.proj_size if self.proj_size > 0 else self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - c_zeros = torch.zeros( - self.num_layers * self.num_directions, - input.size(self.batch_dim), - self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - hx = (h_zeros, c_zeros) - - self.check_forward_args(input, hx, None) - - flat_weights = self._get_flat_weights() - input, *flat_weights = self.grad_filter(input, *flat_weights) - - result = _VF.lstm( - input, - hx, - flat_weights, - self.bias, - self.num_layers, - self.dropout, - self.training, - self.bidirectional, - self.batch_first, - ) - - output = result[0] - hidden = result[1:] - return output, hidden - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. 
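    For example, with min_positive=0.05 and max_factor=0.04, a channel whose
    inputs are never positive has its negative derivative values multiplied by
    up to 1.04 and its positive derivative values by up to 0.96, while a
    channel that is positive exactly 5% of the time is left unchanged.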
- - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - sign_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_positive and max_positive - are violated. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - min_prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. Early in training we may use - higher probabilities than this; it will decay to this value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, - ): - super(ActivationBalancer, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - self.min_prob = min_prob - self.sign_gain_factor = sign_gain_factor - self.scale_gain_factor = scale_gain_factor - - # count measures how many times the forward() function has been called. - # We occasionally sync this to a tensor called `count`, that exists to - # make sure it is synced to disk when we load and save the model. - self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): - return _no_op(x) - - count = self.cpu_count - self.cpu_count += 1 - - if random.random() < 0.01: - # Occasionally sync self.cpu_count with self.count. - # count affects the decay of 'prob'. don't do this on every iter, - # because syncing with the GPU is slow. 
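            # For intuition about the schedule below: prob starts at 0.5,
            # halves roughly every 4000 calls (0.25 at count=4000, 0.125 at
            # count=8000) and settles at the min_prob floor (0.1 by default)
            # after about 9300 calls.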
- self.cpu_count = max(self.cpu_count, self.count.item()) - self.count.fill_(self.cpu_count) - - # the prob of doing some work exponentially decreases from 0.5 till it hits - # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) - - if random.random() < prob: - sign_gain_factor = 0.5 - if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) - else: - sign_factor = None - - scale_factor = _compute_scale_factor( - x.detach(), - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) - return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. 
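    Worked example: if within every group the centered covariance equals
    lambda * I for the same lambda, then the mean diagonal entry is lambda and
    the mean diagonal of C @ C is lambda**2, so the metric is exactly 1.0; any
    spread among the eigenvalues makes the mean of lambda_i**2 exceed the
    square of their mean, which pushes the metric above 1.0.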
- """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float - ) -> Tensor: - ctx.save_for_backward(x) - ctx.num_groups = num_groups - ctx.whitening_limit = whitening_limit - ctx.grad_scale = grad_scale - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, ctx.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) - - (metric - ctx.whitening_limit).relu().backward() - penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 
0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert whitening_limit >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - if isinstance(prob, float): - assert 0 < prob <= 1 - self.prob = prob - else: - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob < self.max_prob <= 1 - self.prob = self.max_prob - - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: - return _no_op(x) - else: - if hasattr(self, "min_prob") and random.random() < 0.25: - # occasionally switch between min_prob and max_prob, based on whether - # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): - # there would be a change to the grad. - self.prob = self.max_prob - else: - self.prob = self.min_prob - - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor): - ctx.y_shape = y.shape - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), - ) - - -def with_loss(x, y): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y) - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class MaxEig(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to discourage - that any given direction in activation space accounts for more than - a specified proportion of the covariance (e.g. 0.2). - - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - max_var_per_eig: the maximum proportion of the variance of the - features/channels, after mean subtraction, that can come from - any given eigenvalue. - min_prob: the minimum probability with which we apply this during any invocation - of forward(), assuming last time we applied the constraint it was - not active; supplied for speed. 
- scale: determines the scale with which we modify the gradients, relative - to the existing / unmodified gradients - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, - ): - super(MaxEig, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.scale = scale - assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels - self.max_var_per_eig = max_var_per_eig - - # we figure out the dominant direction using the power method: starting with - # a random vector, keep multiplying by the covariance and renormalizing. - with torch.no_grad(): - # arbitrary.. would use randn() but want to leave the rest of the model's - # random parameters unchanged for comparison - direction = torch.arange(num_channels).to(torch.float) - direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) - - self.min_prob = min_prob - # cur_prob is the current probability we'll use to apply the ActivationBalancer. - # We'll regress this towards prob, each tiem we try to apply it and it is not - # active. - self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - or torch.jit.is_tracing() - ): - return _no_op(x) - - with torch.cuda.amp.autocast(enabled=False): - eps = 1.0e-20 - orig_x = x - x = x.to(torch.float32) - with torch.no_grad(): - x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) - x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - - # ensure new direction is nonzero even if x == 0, by including `direction`. - self._set_direction(0.1 * self.max_eig_direction + new_direction) - - if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) - - if variance_proportion >= self.max_var_per_eig: - # The constraint is active. Note, we should quite rarely - # reach here, only near the beginning of training if we are - # starting to diverge, should this constraint be active. - cur_prob = self.cur_prob - self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) - else: - # let self.cur_prob exponentially approach self.min_prob, as - # long as the constraint is inactive. 
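                # Concretely: starting from cur_prob == 1.0 (set whenever the
                # constraint fires), n consecutive inactive calls leave
                # cur_prob = min_prob + 0.75**n * (1.0 - min_prob), i.e. the
                # gap to min_prob shrinks by 25% per call.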
- self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob - return orig_x - - def _set_direction(self, direction: Tensor): - """ - Sets self.max_eig_direction to a normalized version of `direction` - """ - direction = direction.detach() - direction = direction / direction.norm() - direction_sum = direction.sum().item() - if direction_sum - direction_sum == 0: # no inf/nan - self.max_eig_direction[:] = direction - else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) - - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ - (num_frames, num_channels) = x.shape - assert num_channels > 1 and num_frames > 1 - assert prev_direction.shape == (num_channels,) - # `coeffs` are the coefficients of `prev_direction` in x. - # actually represent the coeffs up to a constant positive factor. - coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) - return cur_direction, coeffs - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.043637 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. 
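                # Worked example of the quantization above: deriv = 0.5 maps to
                # (0.5 + 0.043637) * 255 / 1.243637 ~= 111.47; adding uniform
                # noise in [0, 1) and truncating to uint8 gives 111 or 112 with
                # an expected value of ~111.47, so the rounding is unbiased in
                # expectation.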
- assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -def _test_max_eig(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad, atol=1.0e-02) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - min_prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = 
m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_softmax() - _test_whiten() - _test_max_eig() - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py deleted file mode 100755 index 148cafd4b..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ /dev/null @@ -1,1449 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# 2023 Johns Hopkins University (author: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -cd egs/libricss/SURT -./prepare.sh - -./dprnn_zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir dprnn_zipformer/exp \ - --max-duration 300 - -# For mix precision training: - -./dprnn_zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir dprnn_zipformer/exp \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriCssAsrDataModule -from decoder import Decoder -from dprnn import DPRNN -from einops.layers.torch import Rearrange -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import LOG_EPSILON, fix_random_seed -from model import SURT -from optim import Eden, ScaledAdam -from scaling import ScaledLSTM -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-mask-encoder-layers", - type=int, - default=4, - help="Number of layers in the DPRNN based mask encoder.", - ) - - parser.add_argument( - "--mask-encoder-dim", - type=int, - default=256, - help="Hidden dimension of the LSTM blocks in DPRNN.", - ) - - parser.add_argument( - "--mask-encoder-segment-size", - type=int, - default=32, - help="Segment size of the SegLSTM in DPRNN. 
Ideally, this should be equal to the " - "decode-chunk-length of the zipformer encoder.", - ) - - parser.add_argument( - "--chunk-width-randomization", - type=bool, - default=False, - help="Whether to randomize the chunk width in DPRNN.", - ) - - # Zipformer config is based on: - # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,2,2,2", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="768,768,768,768,768", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="256,256,256,256,256", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="192,192,192,192,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--use-joint-encoder-layer", - type=str, - default="lstm", - choices=["linear", "lstm", "none"], - help="Whether to use a joint layer to combine all branches.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-init-ckpt", - type=str, - default=None, - help="""The model checkpoint to initialize the model (either full or part). - If not specified, the model is randomly initialized. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.004, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--heat-loss-scale", - type=float, - default=0.0, - help="Scale for HEAT loss on separated sources.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=1, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for SURT - "num_channels": 2, - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed - # parameters for Noam - "model_warm_step": 5000, # arg given to model, not for lrate - # parameters for ctc loss - "beam_size": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def get_mask_encoder_model(params: AttributeDict) -> nn.Module: - mask_encoder = DPRNN( - feature_dim=params.feature_dim, - input_size=params.mask_encoder_dim, - hidden_size=params.mask_encoder_dim, - output_size=params.feature_dim * params.num_channels, - segment_size=params.mask_encoder_segment_size, - num_blocks=params.num_mask_encoder_layers, - chunk_width_randomization=params.chunk_width_randomization, - ) - return mask_encoder - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: - class TakeFirst(nn.Module): - def forward(self, x): - return x[0] - - if params.use_joint_encoder_layer == "linear": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - nn.Linear( - params.num_channels * encoder_dim, params.num_channels * encoder_dim - ), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "lstm": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - ScaledLSTM( - input_size=params.num_channels * encoder_dim, - hidden_size=params.num_channels * encoder_dim, - num_layers=1, - bias=True, - batch_first=True, - dropout=0.0, - bidirectional=False, - ), - TakeFirst(), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "none": - joint_layer = None - else: - raise ValueError( - f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" - ) - return joint_layer - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - 
encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_surt_model( - params: AttributeDict, -) -> nn.Module: - mask_encoder = get_mask_encoder_model(params) - encoder = get_encoder_model(params) - joint_layer = get_joint_encoder_layer(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = SURT( - mask_encoder=mask_encoder, - encoder=encoder, - joint_encoder_layer=joint_layer, - decoder=decoder, - joiner=joiner, - num_channels=params.num_channels, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor: - """ - Compute HEAT loss for separated sources using the output of mask encoder. - Args: - x_masked: - The output of mask encoder. It is a tensor of shape (B, T, C). - batch: - A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()` - for the content in it. - num_channels: - The number of output branches in the SURT model. - """ - B, T, D = x_masked[0].shape - device = x_masked[0].device - - # Create training targets for each channel. - targets = [] - for i in range(num_channels): - target = torch.ones_like(x_masked[i]) * LOG_EPSILON - targets.append(target) - - source_feats = batch["source_feats"] - source_boundaries = batch["source_boundaries"] - input_lens = batch["input_lens"].to(device) - # Assign sources to channels based on the HEAT criteria - for b in range(B): - cut_source_feats = source_feats[b] - cut_source_boundaries = source_boundaries[b] - last_seg_end = [0 for _ in range(num_channels)] - for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries): - assigned = False - for i in range(num_channels): - if start >= last_seg_end[i]: - targets[i][b, start:end, :] += source_feat.to(device) - last_seg_end[i] = max(end, last_seg_end[i]) - assigned = True - break - if not assigned: - min_end_channel = last_seg_end.index(min(last_seg_end)) - targets[min_end_channel][b, start:end, :] += source_feat - last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel]) - - # Get padding mask based on input lengths - pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1) - pad_mask = pad_mask.unsqueeze(-1) - - # Compute masked loss for each channel - losses = torch.zeros((num_channels, B, T, D), device=device) - for i in range(num_channels): - loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none") - # Apply padding mask to loss - loss.masked_fill_(pad_mask, 0) - losses[i] = loss - - # loss: C x B x T x D. pad_mask: B x T x 1 - # We want to compute loss for each item in the batch. Each item has loss given - # by the sum over C, and average over T and D. For T, we need to use the padding. - loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device) - return loss - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
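    Returns:
      Return a tuple of two elements: the total loss (a scalar Tensor, summed
      over the utterances in the batch) and a MetricsTracker containing the
      statistics that are written to the logs and to tensorboard.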
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"].to(device) - feature_lens = batch["input_lens"].to(device) - - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - text = [val for tup in zip(*batch["text"]) for val in tup] - assert len(text) == len(feature) * params.num_channels - - # Convert all channel texts to token IDs and create a ragged tensor. - y = sp.encode(text, out_type=int) - y = k2.RaggedTensor(y).to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.model_warm_step - - with torch.set_grad_enabled(is_training): - (simple_loss, pruned_loss, ctc_loss, x_masked) = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - reduction="none", - subsampling_factor=params.subsampling_factor, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - ctc_loss_is_finite = torch.isfinite(ctc_loss) - - # Compute HEAT loss - if is_training and params.heat_loss_scale > 0.0: - heat_loss = compute_heat_loss( - x_masked, batch, num_channels=params.num_channels - ) - else: - heat_loss = torch.tensor(0.0, device=device) - - heat_loss_is_finite = torch.isfinite(heat_loss) - is_finite = ( - simple_loss_is_finite - & pruned_loss_is_finite - & ctc_loss_is_finite - & heat_loss_is_finite - ) - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_losses: {simple_loss}\n" - f"pruned_losses: {pruned_loss}\n" - f"ctc_losses: {ctc_loss}\n" - f"heat_losses: {heat_loss}\n" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - ctc_loss = ctc_loss[ctc_loss_is_finite] - heat_loss = heat_loss[heat_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if ( - torch.all(~simple_loss_is_finite) - or torch.all(~pruned_loss_is_finite) - or torch.all(~ctc_loss_is_finite) - or torch.all(~heat_loss_is_finite) - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss_sum = simple_loss.sum() - pruned_loss_sum = pruned_loss.sum() - ctc_loss_sum = ctc_loss.sum() - heat_loss_sum = heat_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. 
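        # For example, with the default simple_loss_scale=0.5 and
        # model_warm_step=5000: at batch 0 the simple loss has weight 1.0 and
        # the pruned loss 0.1; at batch 2500 the weights are 0.75 and 0.55;
        # from batch 5000 onwards they stay at 0.5 and 1.0.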
- simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = ( - simple_loss_scale * simple_loss_sum - + pruned_loss_scale * pruned_loss_sum - + params.ctc_loss_scale * ctc_loss_sum - + params.heat_loss_scale * heat_loss_sum - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss_sum.detach().cpu().item() - info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() - if params.ctc_loss_scale > 0.0: - info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() - if params.heat_loss_scale > 0.0: - info["heat_loss"] = heat_loss_sum.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - train_dl_warmup: Optional[torch.utils.data.DataLoader], - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - train_dl_warmup: - Dataloader for the training dataset with 2 speakers. This is used during the - warmup stage. - valid_dl: - Dataloader for the validation dataset. 
- scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - torch.cuda.empty_cache() - model.train() - - tot_loss = MetricsTracker() - - iter_train = iter(train_dl) - iter_train_warmup = iter(train_dl_warmup) if train_dl_warmup is not None else None - - batch_idx = 0 - - while True: - # We first sample a batch from the main dataset. This is because we want to - # make sure all epochs have the same number of batches. - try: - batch = next(iter_train) - except StopIteration: - break - - # If we are in warmup stage, get the batch from the warmup dataset. - if ( - params.batch_idx_train <= params.model_warm_step - and iter_train_warmup is not None - ): - try: - batch = next(iter_train_warmup) - except StopIteration: - iter_train_warmup = iter(train_dl_warmup) - batch = next(iter_train_warmup) - - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = batch["inputs"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
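[Editor's aside, not part of the deleted recipe] A minimal sketch of the grad-scale override that the comment above (and the deleted lines that follow) rely on, assuming a CUDA device with fp16 enabled: `GradScaler.update()` accepts an explicit new scale, which is what lets the training loop double a too-small scale immediately instead of waiting for `_growth_interval` successful steps.

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler(enabled=True, init_scale=1.0)  # assumes CUDA is available
    # Run one dummy backward pass so the scaler materializes its internal scale tensor.
    loss = (torch.randn(4, device="cuda", requires_grad=True) ** 2).sum()
    scaler.scale(loss).backward()
    cur_grad_scale = scaler._scale.item()  # same private attribute the recipe reads
    if cur_grad_scale < 1.0:
        # Force the scale upward right away rather than waiting for the
        # scaler's own growth schedule to kick in.
        scaler.update(cur_grad_scale * 2.0)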
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - - if checkpoints is None and params.model_init_ckpt is not None: - logging.info( - f"Initializing model with checkpoint from {params.model_init_ckpt}" - ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) - model.load_state_dict(init_ckpt["model"], strict=False) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - libricss = LibriCssAsrDataModule(args) - - train_cuts = libricss.lsmix_cuts(rvb_affix="comb", type_affix="full", sources=True) - train_cuts_ov40 = libricss.lsmix_cuts( - rvb_affix="comb", type_affix="ov40", sources=True - ) - dev_cuts = libricss.libricss_cuts(split="dev", type="sdm") - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = libricss.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - ) - train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40) - valid_dl = libricss.valid_dataloaders(dev_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in 
range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - train_dl_warmup=train_dl_ov40, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = [sp.encode(text_ch) for text_ch in batch["text"]] - num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] - logging.info(f"num tokens: {num_tokens}") - - -def main(): - parser = get_parser() - LibriCssAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py deleted file mode 100755 index 8c37430ec..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ /dev/null @@ -1,1342 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES=0 - -./dprnn_zipformer/train.py \ - --world-size 1 \ - --num-epochs 15 \ - --start-epoch 1 \ - --exp-dir dprnn_zipformer/exp \ - --max-duration 300 - -# For mix precision training: - -./dprnn_zipformer/train.py \ - --world-size 1 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir dprnn_zipformer/exp \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from itertools import chain -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriCssAsrDataModule -from decoder import Decoder -from dprnn import DPRNN -from einops.layers.torch import Rearrange -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import LOG_EPSILON, fix_random_seed -from model import SURT -from optim import Eden, ScaledAdam -from scaling import ScaledLinear, ScaledLSTM -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-mask-encoder-layers", - type=int, - default=4, - help="Number of layers in the DPRNN based mask encoder.", - ) - - parser.add_argument( - "--mask-encoder-dim", - type=int, - default=256, - help="Hidden dimension of the LSTM blocks in DPRNN.", - ) - - parser.add_argument( - "--mask-encoder-segment-size", - type=int, - default=32, - help="Segment size of the SegLSTM in DPRNN. 
Ideally, this should be equal to the " - "decode-chunk-length of the zipformer encoder.", - ) - - parser.add_argument( - "--chunk-width-randomization", - type=bool, - default=False, - help="Whether to randomize the chunk width in DPRNN.", - ) - - # Zipformer config is based on: - # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,2,2,2", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="768,768,768,768,768", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="256,256,256,256,256", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="192,192,192,192,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--use-joint-encoder-layer", - type=str, - default="lstm", - choices=["linear", "lstm", "none"], - help="Whether to use a joint layer to combine all branches.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=15, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-init-ckpt", - type=str, - default=None, - help="""The model checkpoint to initialize the model (either full or part). - If not specified, the model is randomly initialized. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.0004, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=1000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=2, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=1000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=5, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
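[Editor's aside, not part of the deleted recipe, which uses `icefall.checkpoint.update_averaged_model`] A minimal sketch of the running average described in the `--average-period` help above; `update_average` is a hypothetical helper whose weights follow the formula quoted in that help string.

    import torch
    import torch.nn as nn

    def update_average(
        model_avg: nn.Module,
        model_cur: nn.Module,
        average_period: int,
        batch_idx_train: int,
    ) -> None:
        # model_avg = model * (average_period / batch_idx_train)
        #             + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
        w_cur = average_period / batch_idx_train
        with torch.no_grad():
            for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
                p_avg.mul_(1.0 - w_cur).add_(p_cur, alpha=w_cur)

    # Example: after 500 batches with --average-period=100, the current model
    # contributes 100/500 = 0.2 and the previous average contributes 0.8.
    m_cur, m_avg = nn.Linear(4, 4), nn.Linear(4, 4)
    update_average(m_avg, m_cur, average_period=100, batch_idx_train=500)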
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 100, - # parameters for SURT - "num_channels": 2, - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed - # parameters for Noam - "model_warm_step": 5000, # arg given to model, not for lrate - # parameters for ctc loss - "beam_size": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def get_mask_encoder_model(params: AttributeDict) -> nn.Module: - mask_encoder = DPRNN( - feature_dim=params.feature_dim, - input_size=params.mask_encoder_dim, - hidden_size=params.mask_encoder_dim, - output_size=params.feature_dim * params.num_channels, - segment_size=params.mask_encoder_segment_size, - num_blocks=params.num_mask_encoder_layers, - chunk_width_randomization=params.chunk_width_randomization, - ) - return mask_encoder - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: - class TakeFirst(nn.Module): - def forward(self, x): - return x[0] - - if params.use_joint_encoder_layer == "linear": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - nn.Linear( - params.num_channels * encoder_dim, params.num_channels * encoder_dim - ), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "lstm": - encoder_dim = int(params.encoder_dims.split(",")[-1]) - joint_layer = nn.Sequential( - Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), - ScaledLSTM( - input_size=params.num_channels * encoder_dim, - hidden_size=params.num_channels * encoder_dim, - num_layers=1, - bias=True, - batch_first=True, - dropout=0.0, - bidirectional=False, - ), - TakeFirst(), - nn.ReLU(), - Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), - ) - elif params.use_joint_encoder_layer == "none": - joint_layer = None - else: - raise ValueError( - f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" - ) - return joint_layer - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - 
decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_surt_model( - params: AttributeDict, -) -> nn.Module: - mask_encoder = get_mask_encoder_model(params) - encoder = get_encoder_model(params) - joint_layer = get_joint_encoder_layer(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = SURT( - mask_encoder=mask_encoder, - encoder=encoder, - joint_encoder_layer=joint_layer, - decoder=decoder, - joiner=joiner, - num_channels=params.num_channels, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"].to(device) - feature_lens = batch["input_lens"].to(device) - - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - text = [val for tup in zip(*batch["text"]) for val in tup] - assert len(text) == len(feature) * params.num_channels - - # Convert all channel texts to token IDs and create a ragged tensor. - y = sp.encode(text, out_type=int) - y = k2.RaggedTensor(y).to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.model_warm_step - - with torch.set_grad_enabled(is_training): - (simple_loss, pruned_loss, ctc_loss, x_masked) = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - reduction="none", - subsampling_factor=params.subsampling_factor, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - ctc_loss_is_finite = torch.isfinite(ctc_loss) - - is_finite = simple_loss_is_finite & pruned_loss_is_finite & ctc_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_losses: {simple_loss}\n" - f"pruned_losses: {pruned_loss}\n" - f"ctc_losses: {ctc_loss}\n" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - ctc_loss = ctc_loss[ctc_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if ( - torch.all(~simple_loss_is_finite) - or torch.all(~pruned_loss_is_finite) - or torch.all(~ctc_loss_is_finite) - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." 
- ) - - simple_loss_sum = simple_loss.sum() - pruned_loss_sum = pruned_loss.sum() - ctc_loss_sum = ctc_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = ( - simple_loss_scale * simple_loss_sum - + pruned_loss_scale * pruned_loss_sum - + params.ctc_loss_scale * ctc_loss_sum - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss_sum.detach().cpu().item() - info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() - if params.ctc_loss_scale > 0.0: - info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - train_dl_warmup: - Dataloader for the training dataset with 2 speakers. 
This is used during the - warmup stage. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - torch.cuda.empty_cache() - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = batch["inputs"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - - if checkpoints is None and params.model_init_ckpt is not None: - logging.info( - f"Initializing model with checkpoint from {params.model_init_ckpt}" - ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) - model.load_state_dict(init_ckpt["model"], strict=True) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - libricss = LibriCssAsrDataModule(args) - - train_cuts_ihm = libricss.libricss_cuts(split="dev", type="ihm-mix") - train_cuts_sdm = libricss.libricss_cuts(split="dev", type="sdm") - train_cuts = train_cuts_ihm + train_cuts_sdm - - # This will create 2 copies of the sessions with different segmentation - train_cuts = train_cuts.trim_to_supervision_groups( - max_pause=0.1 - ) + train_cuts.trim_to_supervision_groups(max_pause=0.5) - dev_cuts = libricss.libricss_cuts(split="dev", type="sdm") - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = libricss.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - return_sources=False, - strict=False, - ) - valid_dl = libricss.valid_dataloaders(dev_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if 
checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = [sp.encode(text_ch) for text_ch in batch["text"]] - num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] - logging.info(f"num tokens: {num_tokens}") - - -def main(): - parser = get_parser() - LibriCssAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_zipformer/zipformer.py b/egs/libricss/SURT/dprnn_zipformer/zipformer.py deleted file mode 120000 index ec183baa7..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/libricss/SURT/heat.png b/egs/libricss/SURT/heat.png deleted file mode 100644 index ac7ecfff4..000000000 Binary files a/egs/libricss/SURT/heat.png and /dev/null differ diff --git a/egs/libricss/SURT/local/add_source_feats.py b/egs/libricss/SURT/local/add_source_feats.py deleted file mode 100755 index c9775561f..000000000 --- a/egs/libricss/SURT/local/add_source_feats.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file adds source features as temporal arrays to the mixture manifests. -It looks for manifests in the directory data/manifests. -""" -import logging -from pathlib import Path - -import numpy as np -from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy -from tqdm import tqdm - - -def add_source_feats(num_jobs=1): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - for type_affix in ["full", "ov40"]: - logging.info(f"Adding source features for {type_affix}") - mixed_name_clean = f"train_clean_{type_affix}" - mixed_name_rvb = f"train_rvb_{type_affix}" - - logging.info("Reading mixed cuts") - mixed_cuts_clean = load_manifest_lazy( - src_dir / f"cuts_{mixed_name_clean}.jsonl.gz" - ) - mixed_cuts_rvb = load_manifest_lazy(src_dir / f"cuts_{mixed_name_rvb}.jsonl.gz") - - logging.info("Reading source cuts") - source_cuts = load_manifest(src_dir / "librispeech_cuts_train_trimmed.jsonl.gz") - - logging.info("Adding source features to the mixed cuts") - with tqdm() as pbar, CutSet.open_writer( - src_dir / f"cuts_{mixed_name_clean}_sources.jsonl.gz" - ) as cut_writer_clean, CutSet.open_writer( - src_dir / f"cuts_{mixed_name_rvb}_sources.jsonl.gz" - ) as cut_writer_rvb, LilcomChunkyWriter( - output_dir / f"feats_train_{type_affix}_sources" - ) as source_feat_writer: - for cut_clean, cut_rvb in zip(mixed_cuts_clean, mixed_cuts_rvb): - assert cut_rvb.id == cut_clean.id + "_rvb" - # Create source_feats and source_feat_offsets - # (See `lhotse.datasets.K2SurtDataset` for details) - source_feats = [] - source_feat_offsets = [] - cur_offset = 0 - for sup in sorted( - cut_clean.supervisions, key=lambda s: (s.start, s.speaker) - ): - source_cut = source_cuts[sup.id] - source_feats.append(source_cut.load_features()) - source_feat_offsets.append(cur_offset) - cur_offset += source_cut.num_frames - cut_clean.source_feats = source_feat_writer.store_array( - cut_clean.id, np.concatenate(source_feats, axis=0) - ) - cut_clean.source_feat_offsets = source_feat_offsets - cut_writer_clean.write(cut_clean) - cut_rvb.source_feats = cut_clean.source_feats - cut_rvb.source_feat_offsets = cut_clean.source_feat_offsets - cut_writer_rvb.write(cut_rvb) - pbar.update(1) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - add_source_feats() diff --git a/egs/libricss/SURT/local/compute_fbank_libricss.py b/egs/libricss/SURT/local/compute_fbank_libricss.py deleted file mode 100755 index afd66899c..000000000 --- a/egs/libricss/SURT/local/compute_fbank_libricss.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LibriCSS dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import logging -from pathlib import Path - -import pyloudnorm as pyln -import torch -import torch.multiprocessing -from lhotse import LilcomChunkyWriter, load_manifest_lazy -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_libricss(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - cuts_ihm_mix = load_manifest_lazy( - src_dir / "libricss-ihm-mix_segments_all.jsonl.gz" - ) - cuts_sdm = load_manifest_lazy(src_dir / "libricss-sdm_segments_all.jsonl.gz") - - for name, cuts in [("ihm-mix", cuts_ihm_mix), ("sdm", cuts_sdm)]: - dev_cuts = cuts.filter(lambda c: "session0" in c.id) - test_cuts = cuts.filter(lambda c: "session0" not in c.id) - - # If SDM cuts, apply loudness normalization - if name == "sdm": - dev_cuts = dev_cuts.normalize_loudness(target=-23.0) - test_cuts = test_cuts.normalize_loudness(target=-23.0) - - logging.info(f"Extracting fbank features for {name} dev cuts") - _ = dev_cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"libricss-{name}_feats_dev", - manifest_path=src_dir / f"cuts_dev_libricss-{name}.jsonl.gz", - batch_duration=500, - num_workers=2, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Extracting fbank features for {name} test cuts") - _ = test_cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"libricss-{name}_feats_test", - manifest_path=src_dir / f"cuts_test_libricss-{name}.jsonl.gz", - batch_duration=2000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_libricss() diff --git a/egs/libricss/SURT/local/compute_fbank_librispeech.py b/egs/libricss/SURT/local/compute_fbank_librispeech.py deleted file mode 100755 index 5c8aece9c..000000000 --- a/egs/libricss/SURT/local/compute_fbank_librispeech.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LibriSpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_librispeech(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_mel_bins = 80 - - dataset_parts = ( - "train-clean-100", - "train-clean-360", - "train-other-500", - ) - prefix = "librispeech" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=16000), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - manifest_path=f"{src_dir}/{cuts_filename}", - batch_duration=4000, - num_workers=2, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_librispeech() diff --git a/egs/libricss/SURT/local/compute_fbank_lsmix.py b/egs/libricss/SURT/local/compute_fbank_lsmix.py deleted file mode 100755 index da42f8ba1..000000000 --- a/egs/libricss/SURT/local/compute_fbank_lsmix.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE 
for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the synthetically mixed LibriSpeech -train and dev sets. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import logging -import random -import warnings -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import LilcomChunkyWriter, load_manifest -from lhotse.cut import MixedCut, MixTrack, MultiCut -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached -from lhotse.utils import fix_random_seed, uuid4 - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_lsmix(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests = read_manifests_if_cached( - dataset_parts=["train_clean_full", "train_clean_ov40"], - types=["cuts"], - output_dir=src_dir, - prefix="lsmix", - suffix="jsonl.gz", - lazy=True, - ) - - cs = {} - cs["clean_full"] = manifests["train_clean_full"]["cuts"] - cs["clean_ov40"] = manifests["train_clean_ov40"]["cuts"] - - # only uses RIRs and noises from REVERB challenge - real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter( - lambda r: "RVB2014" in r.id - ) - noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter( - lambda r: "RVB2014" in r.id - ) - - # Apply perturbation to the training cuts - logging.info("Applying perturbation to the training cuts") - cs["rvb_full"] = cs["clean_full"].map( - lambda c: augment( - c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True - ) - ) - cs["rvb_ov40"] = cs["clean_ov40"].map( - lambda c: augment( - c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True - ) - ) - - for type_affix in ["full", "ov40"]: - for rvb_affix in ["clean", "rvb"]: - logging.info( - f"Extracting fbank features for {type_affix} {rvb_affix} training cuts" - ) - cuts = cs[f"{rvb_affix}_{type_affix}"] - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - _ = cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir - / f"lsmix_feats_train_{rvb_affix}_{type_affix}", - manifest_path=src_dir - / f"cuts_train_{rvb_affix}_{type_affix}.jsonl.gz", - 
batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False): - """ - Given a mixed cut, this function optionally applies the following augmentations: - - Perturbing the SNRs of the tracks (in range [-5, 5] dB) - - Reverberation using a randomly selected RIR - - Adding noise - - Perturbing the loudness (in range [-20, -25] dB) - """ - out_cut = cut.drop_features() - - # Perturb the SNRs (optional) - if perturb_snr: - snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))] - for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)): - if i == 0: - # Skip the first track since it is the reference - continue - track.snr = snr - - # Reverberate the cut (optional) - if rirs is not None: - # Select an RIR at random - rir = random.choice(rirs) - # Select a channel at random - rir_channel = random.choice(list(range(rir.num_channels))) - # Reverberate the cut - out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel]) - - # Add noise (optional) - if noises is not None: - # Select a noise recording at random - noise = random.choice(noises).to_cut() - if isinstance(noise, MultiCut): - noise = noise.to_mono()[0] - # Select an SNR at random - snr = random.uniform(10, 30) - # Repeat the noise to match the duration of the cut - noise = repeat_cut(noise, out_cut.duration) - out_cut = MixedCut( - id=out_cut.id, - tracks=[ - MixTrack(cut=out_cut, type="MixedCut"), - MixTrack(cut=noise, type="DataCut", snr=snr), - ], - ) - - # Perturb the loudness (optional) - if perturb_loudness: - target_loudness = random.uniform(-20, -25) - out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True) - return out_cut - - -def repeat_cut(cut, duration): - while cut.duration < duration: - cut = cut.mix(cut, offset_other_by=cut.duration) - return cut.truncate(duration=duration) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - fix_random_seed(42) - compute_fbank_lsmix() diff --git a/egs/libricss/SURT/local/compute_fbank_musan.py b/egs/libricss/SURT/local/compute_fbank_musan.py deleted file mode 100755 index 1fcf951f9..000000000 --- a/egs/libricss/SURT/local/compute_fbank_musan.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. 
-""" - -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, LilcomChunkyWriter, combine -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = src_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - # create chunks of Musan with duration 5 - 10 seconds - _ = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "musan_feats", - manifest_path=musan_cuts_path, - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - ) - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh deleted file mode 100755 index b2d37f949..000000000 --- a/egs/libricss/SURT/prepare.sh +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/librispeech -# You can find audio and transcripts for LibriSpeech in this path. -# -# - $dl_dir/libricss -# You can find audio and transcripts for LibriCSS in this path. -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -# -# - $dl_dir/rirs_noises -# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/. -# -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data -vocab_size=500 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/librispeech, - # you can create a symlink - # - # ln -sfv /path/to/librispeech $dl_dir/librispeech - # - if [ ! -d $dl_dir/librispeech ]; then - lhotse download librispeech $dl_dir/librispeech - fi - - # If you have pre-downloaded it to /path/to/libricss, - # you can create a symlink - # - # ln -sfv /path/to/libricss $dl_dir/libricss - # - if [ ! -d $dl_dir/libricss ]; then - lhotse download libricss $dl_dir/libricss - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi - - # If you have pre-downloaded it to /path/to/rirs_noises, - # you can create a symlink - # - # ln -sfv /path/to/rirs_noises $dl_dir/ - # - if [ ! -d $dl_dir/rirs_noises ]; then - lhotse download rir-noise $dl_dir/rirs_noises - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LibriSpeech manifests" - # We assume that you have downloaded the LibriSpeech corpus - # to $dl_dir/librispeech. We perform text normalization for the transcripts. - # NOTE: Alignments are required for this recipe. - mkdir -p data/manifests - - log "This recipe uses mfa alignment for trimming" - if [ ! -d $dl_dir/libri_alignments/LibriSpeech ]; then - log "No alignment provided. please refer to ../../librispeech/ASR/add_alignments.sh \n \ - for mfa alignments. Once you have downloaded and unzipped the .zip file containing \n \ - all alignments, the folder should be renamed to libri_alignments and moved to your $dl_dir ." - exit 0 - fi - - lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \ - -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/ -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare LibriCSS manifests" - # We assume that you have downloaded the LibriCSS corpus - # to $dl_dir/libricss. We perform text normalization for the transcripts. 
- mkdir -p data/manifests - for mic in sdm ihm-mix; do - lhotse prepare libricss --type $mic --segmented $dl_dir/libricss data/manifests/ - done -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare musan manifest and RIRs" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - - # We assume that you have downloaded the RIRS_NOISES corpus - # to $dl_dir/rirs_noises - lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises/RIRS_NOISES data/manifests -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts" - # python local/compute_fbank_librispeech.py - lhotse combine data/manifests/librispeech_cuts_train* data/manifests/librispeech_cuts_train_all.jsonl.gz - lhotse cut trim-to-alignments --type word --max-pause 0.2 \ - data/manifests/librispeech_cuts_train_all.jsonl.gz \ - data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz - cat <(gunzip -c data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz) | \ - shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Create simulated mixtures from LibriSpeech (train and dev). This may take a while." - # We create a high overlap set which will be used during the model warmup phase, and a - # full training set that will be used for the subsequent training. - - gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ - grep -v "0L" | grep -v "OV10" |\ - gzip -c > data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz - - gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ - grep "OV40" |\ - gzip -c > data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz - - # Warmup mixtures (100k) based on high overlap (OV40) - log "Generating 100k anechoic train mixtures for warmup" - lhotse workflows simulate-meetings \ - --method conversational \ - --fit-to-supervisions data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz \ - --num-meetings 100000 \ - --num-speakers-per-meeting 2,3 \ - --max-duration-per-speaker 15.0 \ - --max-utterances-per-speaker 3 \ - --seed 1234 \ - --num-jobs 4 \ - data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ - data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz - - # Full training set (2,3 speakers) anechoic - log "Generating anechoic set (full)" - lhotse workflows simulate-meetings \ - --method conversational \ - --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \ - --num-repeats 1 \ - --num-speakers-per-meeting 2,3 \ - --max-duration-per-speaker 15.0 \ - --max-utterances-per-speaker 3 \ - --seed 1234 \ - --num-jobs 4 \ - data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ - data/manifests/lsmix_cuts_train_clean_full.jsonl.gz -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compute fbank features for musan" - mkdir -p data/fbank - python local/compute_fbank_musan.py -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank features for simulated Libri-mix" - mkdir -p data/fbank - python local/compute_fbank_lsmix.py -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Add source feats to mixtures (useful for auxiliary tasks)" - python local/add_source_feats.py - - log "Combining lsmix-clean and lsmix-rvb" - for type in full ov40; do - cat <(gunzip -c 
data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \ - <(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\ - shuf | gzip -c > data/manifests/cuts_train_comb_${type}_sources.jsonl.gz - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compute fbank features for LibriCSS" - mkdir -p data/fbank - python local/compute_fbank_libricss.py -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Download LibriSpeech BPE model from HuggingFace." - mkdir -p data/lang_bpe_500 - pushd data/lang_bpe_500 - wget https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/resolve/main/data/lang_bpe_500/bpe.model - popd -fi diff --git a/egs/libricss/SURT/shared b/egs/libricss/SURT/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/libricss/SURT/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/libricss/SURT/surt.png b/egs/libricss/SURT/surt.png deleted file mode 100644 index fcc8119d4..000000000 Binary files a/egs/libricss/SURT/surt.png and /dev/null differ diff --git a/egs/libriheavy/ASR/README.md b/egs/libriheavy/ASR/README.md deleted file mode 100644 index 2498d017f..000000000 --- a/egs/libriheavy/ASR/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context - -Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105). - - -See [RESULTS](./RESULTS.md) for the results for icefall recipes. diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md deleted file mode 100644 index 513bbf72e..000000000 --- a/egs/libriheavy/ASR/RESULTS.md +++ /dev/null @@ -1,315 +0,0 @@ -# Results - -## zipformer (zipformer + pruned stateless transducer) - -See for more details. - -[zipformer](./zipformer) - -### Non-streaming - -#### Training on normalized text, i.e. Upper case without punctuation - -##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M - -You can find a pretrained model, training logs at: - - -Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), -exp_small_subset(small set). 
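As a concrete illustration of what "normalized text" means here (upper case, punctuation stripped), the sketch below mirrors the `remove_punc_to_upper` helper from `local/norm_text.py` elsewhere in this recipe; the exact character set that is kept is defined there, so treat this as a worked example rather than the authoritative implementation.

```python
def remove_punc_to_upper(text: str) -> str:
    # Keep letters, digits and apostrophes; map everything else to a space.
    text = text.replace("‘", "'").replace("’", "'")
    keep = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    chars = [c.upper() if c in keep else " " for c in text]
    return " ".join("".join(chars).split()).strip()


print(remove_punc_to_upper("Mr. Brown’s house, No. 3!"))  # -> MR BROWN'S HOUSE NO 3
```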
- -Results of models: - -| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment | -|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| -| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 | -| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 | -| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 | -| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 | -| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 | -| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 | - -The training command is: -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -python ./zipformer/train.py \ - --world-size 4 \ - --master-port 12365 \ - --exp-dir zipformer/exp \ - --num-epochs 60 \ # 16 for large; 90 for small - --lr-hours 15000 \ # 20000 for large; 5000 for small - --use-fp16 1 \ - --start-epoch 1 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --max-duration 1000 \ - --subset medium -``` - -The decoding command is: -```bash -export CUDA_VISIBLE_DEVICES="0" -for m in greedy_search modified_beam_search; do - ./zipformer/decode.py \ - --epoch 16 \ - --avg 3 \ - --exp-dir zipformer/exp \ - --max-duration 1000 \ - --causal 0 \ - --decoding-method $m -done -``` - -#### Training on full formatted text, i.e. with casing and punctuation - -##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M - -You can find a pretrained model, training logs at: - - -Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), -exp_small_subset(small set). - -Results of models: - -| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment | -|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| -| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 | -| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 | -| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 | - -The training command is: -```bash -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -python ./zipformer/train.py \ - --world-size 4 \ - --master-port 12365 \ - --exp-dir zipformer/exp \ - --num-epochs 60 \ # 16 for large; 90 for small - --lr-hours 15000 \ # 20000 for large; 10000 for small - --use-fp16 1 \ - --train-with-punctuation 1 \ - --start-epoch 1 \ - --bpe-model data/lang_punc_bpe_756/bpe.model \ - --max-duration 1000 \ - --subset medium -``` - -The decoding command is: -```bash -export CUDA_VISIBLE_DEVICES="0" -for m in greedy_search modified_beam_search; do - ./zipformer/decode.py \ - --epoch 16 \ - --avg 3 \ - --exp-dir zipformer/exp \ - --max-duration 1000 \ - --causal 0 \ - --decoding-method $m -done -``` - -## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) - -#### [zipformer_prompt_asr](./zipformer_prompt_asr) - -See for commit history and -our paper for more details. 
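As background for the commands in the following subsections: the PromptASR setup encodes the content/style prompt with a BERT text encoder into 768-dim embeddings that the ASR encoder can attend to (cf. `--text-encoder-type BERT`, `--text-encoder-dim 768`, `--memory-layer`). The sketch below is illustrative only and is not the recipe's `train_bert_encoder.py`; the Hugging Face model name and the wiring into the zipformer are assumptions.

```python
# Illustrative sketch only (assumptions: bert-base-uncased via Hugging Face transformers).
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased")  # hidden size 768

pre_text = "The history of the Fairchild family, chapter one."
with torch.no_grad():
    tokens = tokenizer(pre_text, return_tensors="pt")
    memory = text_encoder(**tokens).last_hidden_state  # shape (1, num_tokens, 768)

# In the recipe, embeddings like `memory` are exposed to the ASR encoder
# (cf. --memory-layer and --memory-dropout-rate in the commands below).
print(memory.shape)
```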
- - - -##### Training on the medium subset, with content & style prompt, **no** context list - -You can find a pre-trained model, training logs, decoding logs, and decoding results at: - -The training command is: - -```bash -causal=0 -subset=medium -memory_dropout_rate=0.05 -text_encoder_type=BERT - -python ./zipformer_prompt_asr/train_bert_encoder.py \ - --world-size 4 \ - --start-epoch 1 \ - --num-epochs 60 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --use-fp16 True \ - --memory-dropout-rate $memory_dropout_rate \ - --causal $causal \ - --subset $subset \ - --manifest-dir data/fbank \ - --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ - --max-duration 1000 \ - --text-encoder-type $text_encoder_type \ - --text-encoder-dim 768 \ - --use-context-list 0 \ - --top-k $top_k \ - --use-style-prompt 1 -``` - -The decoding results using utterance-level context (epoch-60-avg-10): - -| decoding method | lh-test-clean | lh-test-other | comment | -|----------------------|---------------|---------------|---------------------| -| modified_beam_search | 3.13 | 6.78 | --use-pre-text False --use-style-prompt False | -| modified_beam_search | 2.86 | 5.93 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | -| modified_beam_search | 2.6 | 5.5 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | - - -The decoding command is: - -```bash -for style in mixed-punc upper-no-punc; do - python ./zipformer_prompt_asr/decode_bert.py \ - --epoch 60 \ - --avg 10 \ - --use-averaged-model True \ - --post-normalization True \ - --causal False \ - --exp-dir ./zipformer_prompt_asr/exp \ - --manifest-dir data/fbank \ - --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --text-encoder-type BERT \ - --text-encoder-dim 768 \ - --memory-layer 0 \ - --use-ls-test-set False \ - --use-ls-context-list False \ - --max-prompt-lens 1000 \ - --use-pre-text True \ - --use-style-prompt True \ - --style-text-transform $style \ - --pre-text-transform $style \ - --compute-CER 0 -done -``` - -##### Training on the medium subset, with content & style prompt, **with** context list - -You can find a pre-trained model, training logs, decoding logs, and decoding results at: - -This model is trained with an extra type of content prompt (context words), thus it does better -on **word-level** context biasing. Note that to train this model, please first run `prepare_prompt_asr.sh` -to prepare a manifest containing context words. 
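The context words added by `prepare_prompt_asr.sh` are rare words mined from the training transcripts (its stage 1 passes a `--top-k` threshold to `zipformer_prompt_asr/utils.py`). A minimal sketch of that idea follows; the actual selection logic in `utils.py` may differ, so the function below is illustrative only.

```python
# Illustrative sketch: words outside the top-k most frequent ones are treated as
# "rare" and collected per utterance as candidate context (biasing) words.
from collections import Counter
from typing import List


def rare_words_per_utterance(transcripts: List[str], top_k: int = 10000) -> List[List[str]]:
    counts = Counter(word for text in transcripts for word in text.split())
    common = {word for word, _ in counts.most_common(top_k)}
    return [[w for w in text.split() if w not in common] for text in transcripts]


texts = [
    "THE CAT SAT ON THE MAT",
    "THE ARCHAEOPTERYX FLEW OVER THE MAT",
]
# With a tiny top_k, only the most frequent words count as "common".
print(rare_words_per_utterance(texts, top_k=2))
```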
- -The training command is: - -```bash - -causal=0 -subset=medium -memory_dropout_rate=0.05 -text_encoder_type=BERT -use_context_list=True - -# prepare the required data for context biasing -./prepare_prompt_asr.sh --stage 0 --stop_stage 1 - -python ./zipformer_prompt_asr/train_bert_encoder.py \ - --world-size 4 \ - --start-epoch 1 \ - --num-epochs 50 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --use-fp16 True \ - --memory-dropout-rate $memory_dropout_rate \ - --causal $causal \ - --subset $subset \ - --manifest-dir data/fbank \ - --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ - --max-duration 1000 \ - --text-encoder-type $text_encoder_type \ - --text-encoder-dim 768 \ - --use-context-list $use_context_list \ - --top-k 10000 \ - --use-style-prompt 1 -``` - -*Utterance-level biasing:* - -| decoding method | lh-test-clean | lh-test-other | comment | -|----------------------|---------------|---------------|---------------------| -| modified_beam_search | 3.17 | 6.72 | --use-pre-text 0 --use-style-prompt 0 | -| modified_beam_search | 2.91 | 6.24 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | -| modified_beam_search | 2.72 | 5.72 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | - - -The decoding command for the table above is: - -```bash -for style in mixed-punc upper-no-punc; do - python ./zipformer_prompt_asr/decode_bert.py \ - --epoch 50 \ - --avg 10 \ - --use-averaged-model True \ - --post-normalization True \ - --causal False \ - --exp-dir ./zipformer_prompt_asr/exp \ - --manifest-dir data/fbank \ - --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --text-encoder-type BERT \ - --text-encoder-dim 768 \ - --memory-layer 0 \ - --use-ls-test-set False \ - --use-ls-context-list False \ - --max-prompt-lens 1000 \ - --use-pre-text True \ - --use-style-prompt True \ - --style-text-transform $style \ - --pre-text-transform $style \ - --compute-CER 0 -done -``` - -*Word-level biasing:* - -The results are reported on LibriSpeech test-sets using the biasing list provided from . -You need to set `--use-ls-test-set True` so that the LibriSpeech test sets are used. 
- -| decoding method | ls-test-clean | ls-test-other | comment | -|----------------------|---------------|---------------|---------------------| -| modified_beam_search | 2.4 | 5.08 | --use-pre-text 0 --use-style-prompt 0 | -| modified_beam_search | 2.14 | 4.62 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 0 | -| modified_beam_search | 2.14 | 4.64 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 100 | - -The decoding command for the table above is: - -```bash -use_ls_test_set=1 -use_ls_context_list=1 - -for ls_distractors in 0 100; do - python ./zipformer_prompt_asr/decode_bert.py \ - --epoch 50 \ - --avg 10 \ - --use-averaged-model True \ - --post-normalization True \ - --causal False \ - --exp-dir ./zipformer_prompt_asr/exp \ - --manifest-dir data/fbank \ - --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --text-encoder-type BERT \ - --text-encoder-dim 768 \ - --memory-layer 0 \ - --use-ls-test-set $use_ls_test_set \ - --use-ls-context-list $use_ls_context_list \ - --ls-distractors $ls_distractors \ - --max-prompt-lens 1000 \ - --use-pre-text True \ - --use-style-prompt True \ - --style-text-transform mixed-punc \ - --pre-text-transform mixed-punc \ - --compute-CER 0 -done - -``` diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py deleted file mode 100755 index 010531db2..000000000 --- a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the Libriheavy dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, -) - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--manifest-dir", - type=str, - help="""The source directory that contains raw manifests.
- """, - default="data/manifests", - ) - - parser.add_argument( - "--fbank-dir", - type=str, - help="""Fbank output dir - """, - default="data/fbank", - ) - - parser.add_argument( - "--subset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Whether to use speed perturbation.", - ) - - parser.add_argument( - "--use-splits", - type=str2bool, - default=False, - help="Whether to compute fbank on splits.", - ) - - parser.add_argument( - "--num-splits", - type=int, - help="""The number of splits of the medium and large subset. - Only needed when --use-splits is true.""", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="""Process pieces starting from this number (inclusive). - Only needed when --use-splits is true.""", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="""Stop processing pieces until this number (exclusive). - Only needed when --use-splits is true.""", - ) - - return parser.parse_args() - - -def compute_fbank_libriheavy(args): - src_dir = Path(args.manifest_dir) - output_dir = Path(args.fbank_dir) - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - subset = args.subset - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz" - if output_cuts_path.exists(): - logging.info(f"{output_cuts_path} exists - skipping") - return - - input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz" - assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!" 
- logging.info(f"Loading {input_cuts_path}") - cut_set = CutSet.from_file(input_cuts_path) - - logging.info("Computing features") - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/libriheavy_feats_{subset}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - logging.info(f"Saving to {output_cuts_path}") - cut_set.to_file(output_cuts_path) - - -def compute_fbank_libriheavy_splits(args): - num_splits = args.num_splits - subset = args.subset - src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split" - src_dir = Path(src_dir) - output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split" - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - num_digits = 8 # num_digits is fixed by lhotse split-lazy - for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) - logging.info(f"Processing {idx}/{num_splits}") - - cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" - if not raw_cuts_path.is_file(): - logging.info(f"{raw_cuts_path} does not exist - skipping it") - continue - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Computing features") - if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists(): - logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca") - os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca") - - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - overwrite=True, - ) - - logging.info("About to split cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - logging.info(f"Saved to {cuts_path}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - - if args.use_splits: - assert args.num_splits is not None, "Please provide num_splits" - compute_fbank_libriheavy_splits(args) - else: - compute_fbank_libriheavy(args) diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/libriheavy/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py deleted file mode 100755 index c2fc0d92d..000000000 --- a/egs/libriheavy/ASR/local/norm_text.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import codecs -import sys - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--text", - type=str, - help="""Path to the input text. - """, - ) - return parser.parse_args() - - -def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s - - -def main(): - args = get_args() - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer) - line = f.readline() - while line: - print(remove_punc_to_upper(line)) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py deleted file mode 100755 index a57a3749d..000000000 --- a/egs/libriheavy/ASR/local/prepare_manifest.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gzip -import json -import sys -from pathlib import Path - -from icefall.utils import str2bool - - -def simple_cleanup(text: str) -> str: - table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") - text = text.translate(table) - return text.strip() - - -# Assign text of the supervisions and remove unnecessary entries. 
-def main(): - assert ( - len(sys.argv) == 4 - ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" - fname = Path(sys.argv[1]).name - oname = Path(sys.argv[2]) / fname - keep_custom_fields = str2bool(sys.argv[3]) - with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: - for line in fin: - cut = json.loads(line) - cut["supervisions"][0]["text"] = simple_cleanup( - cut["supervisions"][0]["custom"]["texts"][0] - ) - if not keep_custom_fields: - del cut["supervisions"][0]["custom"] - del cut["custom"] - fout.write((json.dumps(cut) + "\n").encode()) - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py deleted file mode 100755 index 19caf43ab..000000000 --- a/egs/libriheavy/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--byte-fallback", - action="store_true", - help="""Whether to enable byte_fallback when training bpe.""", - ) - - parser.add_argument( - "--character-coverage", - type=float, - default=1.0, - help="Character coverage in vocabulary.", - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. 
- - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=args.character_coverage, - user_defined_symbols=user_defined_symbols, - byte_fallback=args.byte_fallback, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh deleted file mode 100755 index 366a1459f..000000000 --- a/egs/libriheavy/ASR/prepare.sh +++ /dev/null @@ -1,319 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=100 -export CUDA_VISIBLE_DEVICES="" - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/librilight -# You can find small, medium, large, etc. inside it. -# -# - $dl_dir/libriheavy -# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it. -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -dl_dir=$PWD/download - -# If you want to do PromptASR experiments, please set it to True -# as this will keep the texts and pre_text information required for -# the training of PromptASR. -keep_custom_fields=False - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data -fbank_dir=data/fbank -manifests_dir=data/manifests - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download audio data." - # If you have pre-downloaded it to /path/to/librilight, - # you can create a symlink - # - # ln -sfv /path/to/librilight $dl_dir/librilight - # - mkdir -p $dl_dir/librilight - for subset in small medium large; do - log "Downloading ${subset} subset." - if [ ! -d $dl_dir/librilight/${subset} ]; then - wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar - tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight - else - log "Skipping download, ${subset} subset exists." - fi - done -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download manifests from huggingface." - - # If you have pre-downloaded it to /path/to/libriheavy, - # you can create a symlink - # - # ln -sfv /path/to/libriheavy $dl_dir/libriheavy - # - mkdir -p $dl_dir/libriheavy - for subset in small medium large dev test_clean test_other; do - if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then - log "Downloading ${subset} subset." 
- wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz - else - log "Skipping download, ${subset} subset exists." - fi - done - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Download manifests from modelscope" - mkdir -p $dl_dir/libriheavy - if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then - cd $dl_dir/libriheavy - GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git - cd Libriheavy - git lfs pull --exclude "raw/*" - mv *.jsonl.gz ../ - cd .. - rm -rf Libriheavy - cd ../../ - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p $manifests_dir - if [ ! -e $manifests_dir/.musan.done ]; then - lhotse prepare musan $dl_dir/musan $manifests_dir - touch $manifests_dir/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare Libriheavy manifests" - mkdir -p $manifests_dir - for subset in small medium large dev test_clean test_other; do - if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then - log "Prepare manifest for subset : ${subset}" - ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields - fi - done -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p $fbank_dir - if [ ! -e $fbank_dir/.musan.done ]; then - ./local/compute_fbank_musan.py - touch $fbank_dir/.musan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for small subset and validation subsets" - for subset in test_clean test_other dev small; do - log "Computing $subset subset." - if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then - ./local/compute_fbank_libriheavy.py \ - --manifest-dir ${manifests_dir} \ - --subset ${subset} \ - --fbank-dir $fbank_dir \ - --num-workers $nj - fi - done -fi - -num_per_split=8000 -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Split medium and large subsets." - for subset in medium large; do - log "Spliting subset : $subset" - split_dir=$manifests_dir/libriheavy_${subset}_split - mkdir -p $split_dir - if [ ! -e $split_dir/.split_completed ]; then - lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split - touch $split_dir/.split_completed - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank for medium and large subsets" - mkdir -p $fbank_dir - chunk_size=20 - for subset in medium large; do - if [ $subset == "large" ]; then - chunk_size=200 - fi - num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l) - if [ ! 
-e $fbank_dir/.libriheavy.${subset}.done ]; then - for i in $(seq 0 1 6); do - start=$(( i * $chunk_size )) - end=$(( (i+1) * $chunk_size )) - ./local/compute_fbank_libriheavy.py \ - --manifest-dir ${manifests_dir} \ - --use-splits 1 \ - --subset ${subset} \ - --fbank-dir $fbank_dir \ - --num-splits $num_splits \ - --num-workers $nj \ - --start $start \ - --stop $end & - done - wait - touch $fbank_dir/.libriheavy.${subset}.done - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Combine features for medium and large subsets." - for subset in medium large; do - log "Combining $subset subset." - if [ ! -f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then - pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz") - lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz - fi - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Train BPE model for normalized text" - - if [ ! -f data/texts ]; then - gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ - | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ - | ./local/norm_text.py > data/texts - fi - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - cp data/texts $lang_dir/text - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - done -fi - - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Train BPE model for unnormalized text" - if [ ! -f data/punc_texts ]; then - gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ - | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts - fi - for vocab_size in ${vocab_sizes[@]}; do - new_vocab_size=$(($vocab_size + 256)) - lang_dir=data/lang_punc_bpe_${new_vocab_size} - mkdir -p $lang_dir - - cp data/punc_texts $lang_dir/text - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --byte-fallback \ - --vocab-size ${new_vocab_size} \ - --byte-fallback \ - --character-coverage 0.99 \ - --transcript $lang_dir/text - fi - done -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare language model for normalized text" - - for subset in small medium large; do - if [ ! -f $manifests_dir/texts_${subset} ]; then - gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \ - | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ - | ./local/norm_text.py > $manifests_dir/texts_${subset} - fi - done - - mkdir -p data/lm - if [ ! -f data/lm/text ]; then - cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text - fi - - (echo ' 0'; echo '!SIL 1'; echo ' 2'; echo ' 3';) \ - > data/lm/words.txt - - cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ - | awk '{print $1" "NR+3}' >> data/lm/words.txt - - num_lines=$(< data/lm/words.txt wc -l) - (echo "#0 $num_lines"; echo " $(($num_lines + 1))"; echo " $(($num_lines + 2))";) \ - >> data/lm/words.txt - - # Train LM on transcripts - if [ ! -f data/lm/3-gram.unpruned.arpa ]; then - python3 ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text data/lm/text \ - -lm data/lm/3-gram.unpruned.arpa - fi - - # We assume you have install kaldilm, if not, please install - # it using: pip install kaldilm - if [ ! 
-f data/lm/G_3_gram_char.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table=data/lm/words.txt \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt - fi -fi - diff --git a/egs/libriheavy/ASR/prepare_prompt_asr.sh b/egs/libriheavy/ASR/prepare_prompt_asr.sh deleted file mode 100755 index b931cea26..000000000 --- a/egs/libriheavy/ASR/prepare_prompt_asr.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -stage=-1 -stop_stage=100 -manifest_dir=data/fbank -subset=medium -topk=10000 - -. shared/parse_options.sh || exit 1 - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download the meta biasing list for LibriSpeech" - mkdir -p data/context_biasing - cd data/context_biasing - git clone https://github.com/facebookresearch/fbai-speech.git - cd ../.. -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Add rare-words for context biasing to the manifest" - python zipformer_prompt_asr/utils.py \ - --manifest-dir $manifest_dir \ - --subset $subset \ - --top-k $topk - -fi diff --git a/egs/libriheavy/ASR/shared b/egs/libriheavy/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/libriheavy/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index 4985f3f4c..000000000 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriHeavyAsrDataModule: - """ - DataModule for k2 ASR experiments. 
- It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--subset", - type=str, - default="S", - help="""The subset to be used. Should be S, M or L. Note: S subset - includes libriheavy_cuts_small.jsonl.gz, M subset includes - libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz, - L subset includes libriheavy_cuts_small.jsonl.gz, - libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz. - """, - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. 
- # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_small_cuts(self) -> CutSet: - logging.info("About to get small subset cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" - ) - - @lru_cache() - def train_medium_cuts(self) -> CutSet: - logging.info("About to get medium subset cuts") - return load_manifest_lazy( - 
self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" - ) - - @lru_cache() - def train_large_cuts(self) -> CutSet: - logging.info("About to get large subset cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get the test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get the test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz" - ) diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/libriheavy/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py deleted file mode 100644 index 1928e2635..000000000 --- a/egs/libriheavy/ASR/zipformer/decode.py +++ /dev/null @@ -1,794 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
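# --- Editorial note (illustrative, not part of the original file) -------------
# The usage examples below only vary the decoding method. For a model trained
# on cased and punctuated text, this script can additionally report WERs on
# normalized (upper-cased, punctuation-stripped) references via the
# --train-with-punctuation and --post-normalization flags defined in
# get_parser() further down; a hypothetical invocation:
#
#   ./zipformer/decode.py \
#     --epoch 28 \
#     --avg 15 \
#     --exp-dir ./zipformer/exp \
#     --max-duration 600 \
#     --decoding-method greedy_search \
#     --train-with-punctuation True \
#     --post-normalization True
# ------------------------------------------------------------------------------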
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -import warnings -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from text_normalization import remove_punc_to_upper -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
- "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--train-with-punctuation", - type=str2bool, - default=False, - help="""Set to True, if the model was trained on texts with casing - and punctuation.""", - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=False, - help="""Upper case and remove all chars except ' and - - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. 
`value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise 
ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - this_batch = [] - if params.post_normalization and params.train_with_punctuation: - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = remove_punc_to_upper(ref_text).split() - hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[f"{name}_norm"].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
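# Editorial note (illustrative, not part of the original file): each entry of
# `results` is a (cut_id, ref_words, hyp_words) tuple as assembled in
# decode_dataset() above, e.g.
#   ("cut-id-0", ["HELLO", "WORLD"], ["HELLO", "WORD"])
# write_error_stats() consumes this list, writes the per-word error statistics
# and aligned ref/hyp pairs to the open file handle, and returns the WER for
# this decoding setting, which is collected into test_set_wers below.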
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") 
- model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
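# Editorial note (not part of the original file): decode_dataset() reads cut
# ids from batch["supervisions"]["cut"], and that field is only populated when
# the dataset is built with return_cuts=True (see --return-cuts in
# asr_datamodule.py), so it is forced on here regardless of the command-line
# value.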
- args.return_cuts = True - libriheavy = LibriHeavyAsrDataModule(args) - - def normalize_text(c: Cut): - text = remove_punc_to_upper(c.supervisions[0].text) - c.supervisions[0].text = text - return c - - test_clean_cuts = libriheavy.test_clean_cuts() - test_other_cuts = libriheavy.test_other_cuts() - - if not params.train_with_punctuation: - test_clean_cuts = test_clean_cuts.map(normalize_text) - test_other_cuts = test_other_cuts.map(normalize_text) - - test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts) - test_other_dl = libriheavy.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/libriheavy/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/libriheavy/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/libriheavy/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/libriheavy/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/libriheavy/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/libriheavy/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/libriheavy/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py deleted file mode 120000 index 0573b88c5..000000000 --- a/egs/libriheavy/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py 
b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/libriheavy/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/libriheavy/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/libriheavy/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling_coverter.py b/egs/libriheavy/ASR/zipformer/scaling_coverter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/libriheavy/ASR/zipformer/scaling_coverter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/libriheavy/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py deleted file mode 100644 index 92590769c..000000000 --- a/egs/libriheavy/ASR/zipformer/text_normalization.py +++ /dev/null @@ -1,50 +0,0 @@ -from num2words import num2words - - -def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s - - -def word_normalization(word: str) -> str: - # 1. Use full word for some abbreviation - # 2. Convert digits to english words - # 3. 
Convert ordinal number to english words - if word == "MRS": - return "MISSUS" - if word == "MR": - return "MISTER" - if word == "ST": - return "SAINT" - if word == "ECT": - return "ET CETERA" - - if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH - word = num2words(word[:-2], to="ordinal") - word = word.replace("-", " ") - - if word.isnumeric(): - num = int(word) - if num > 1500 and num < 2030: - word = num2words(word, to="year") - else: - word = num2words(word) - word = word.replace("-", " ") - return word.upper() - - -def text_normalization(text: str) -> str: - text = text.upper() - return " ".join([word_normalization(x) for x in text.split()]) - - -if __name__ == "__main__": - assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK" - assert ( - text_normalization("Hello Mrs st 21st world 3rd she 99th MR") - == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER" - ) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py deleted file mode 100644 index 357e8a827..000000000 --- a/egs/libriheavy/ASR/zipformer/train.py +++ /dev/null @@ -1,1414 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
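# --- Editorial note (illustrative, not part of the original files) ------------
# train.py below imports remove_punc_to_upper from text_normalization.py
# (deleted above). As a quick sanity check, the helpers defined there behave
# like this (outputs follow from the code as written):
#
#   remove_punc_to_upper("it’s 9th!")        # -> "IT'S 9TH"
#   word_normalization("9TH")                # -> "NINTH"
#   text_normalization("mr smith owns 2")    # -> "MISTER SMITH OWNS TWO"
# ------------------------------------------------------------------------------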
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from text_normalization import remove_punc_to_upper -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-hours", - type=float, - default=30000, - help="""Number of hours that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--train-with-punctuation", - type=str2bool, - default=False, - help="If True, the training text will include casing and punctuation.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. 
- - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - 
joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. 
- - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
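# A minimal, dependency-free sketch of the warm-up schedule applied above:
# the weight on the "simple" (un-pruned) transducer loss decays from 1.0 to
# `simple_loss_scale` over the first `warm_step` batches, while the weight on
# the pruned loss ramps from 0.1 up to 1.0 over the same period.
def warmup_loss_scales(batch_idx_train: int, warm_step: int, s: float):
    """Return (simple_loss_scale, pruned_loss_scale) for the current batch."""
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

# For example, with warm_step=2000 (the value in get_params) and a
# hypothetical s=0.5:
#   batch 0    -> (1.00, 0.10)
#   batch 1000 -> (0.75, 0.55)
#   batch 2000 -> (0.50, 1.00)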
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - # Use the number of hours of speech to adjust the learning rate - scheduler.step_epoch( - params.batch_idx_train * params.max_duration * params.world_size / 3600 - ) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed 
automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def normalize_text(c: Cut): - text = remove_punc_to_upper(c.supervisions[0].text) - c.supervisions[0].text = text - return c - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 2.0 or c.duration > 30.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - libriheavy = LibriHeavyAsrDataModule(args) - - train_cuts = libriheavy.train_small_cuts() - if params.subset == "M" or params.subset == "L": - train_cuts += libriheavy.train_medium_cuts() - if params.subset == "L": - train_cuts += libriheavy.train_large_cuts() - - if not params.train_with_punctuation: - train_cuts = train_cuts.map(normalize_text) - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = libriheavy.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = libriheavy.dev_cuts() - - if not params.train_with_punctuation: - valid_cuts = valid_cuts.map(normalize_text) - - valid_dl = libriheavy.valid_dataloaders(valid_cuts) - - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/libriheavy/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py b/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py deleted file mode 100644 index 552f63905..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ /dev/null @@ -1,524 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -import torch -from dataset import PromptASRDataset -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # SingleCutSampler, - CutConcatenate, - CutMix, - DynamicBucketingSampler, - ExtraPadding, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriHeavyAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - if args.use_context_list: - assert args.rare_word_file is not None - with open(args.rare_word_file, "r") as f: - self.rare_word_list = ( - f.read().lower().split() - ) # Use lower-cased for easier style transform - else: - self.rare_word_list = None - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", - ) - - # Libriheavy specific arguments - group.add_argument( - "--subset", - type=str, - default="small", - help="Select the Libriheavy subset (small|medium|large)", - ) - - group.add_argument( - "--use-context-list", - type=str2bool, - default=False, - help="Use the context list of libri heavy", - ) - - group.add_argument( - "--top-k", - type=int, - default=10000, - help="""The top-k words are identified as common words, - the rest as rare words""", - ) - - group.add_argument( - "--with-decoding", - type=str2bool, - default=False, - help="If the texts field contain decoding", - ) - - group.add_argument( - "--random-left-padding", - type=str2bool, - ) - - group.add_argument( - "--rare-word-file", - type=str, - ) - - group.add_argument( - "--long-audio-cuts", - type=str, - default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - text_sampling_func: Callable[[List[str]], str] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
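# A standalone sketch of the version probe performed just below: read the
# default value of SpecAugment's `num_frame_masks` argument with
# `inspect.signature` and choose the setting accordingly.  `FakeSpecAugment`
# is a stand-in so this snippet runs without lhotse; the real code inspects
# lhotse.dataset.SpecAugment.
import inspect

class FakeSpecAugment:
    def __init__(self, num_frame_masks: int = 1):
        self.num_frame_masks = num_frame_masks

default = inspect.signature(FakeSpecAugment.__init__).parameters["num_frame_masks"].default
chosen_num_frame_masks = 2 if default == 1 else 10  # mirrors the logic below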
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = PromptASRDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - text_sampling_func=text_sampling_func, - rare_word_list=self.rare_word_list, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = PromptASRDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - text_sampling_func=text_sampling_func, - rare_word_list=self.rare_word_list, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - raise NotImplementedError( - "SingleCutSampler is no longer supported by lhotse" - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders( - self, - cuts_valid: CutSet, - text_sampling_func: Callable[[List[str]], str] = None, - ) -> DataLoader: - transforms = [] - if self.args.random_left_padding: - logging.info("Enable random left padding") - transforms.append( - ExtraPadding(extra_frames=16, randomized=True, direction="left") - ) - - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = PromptASRDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - text_sampling_func=text_sampling_func, - rare_word_list=self.rare_word_list, - ) - else: - validate = PromptASRDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - text_sampling_func=text_sampling_func, - rare_word_list=self.rare_word_list, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get {self.args.subset} cuts") - - if self.args.use_context_list: - path = ( - self.args.manifest_dir - / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz" - ) - elif self.args.with_decoding: - path = ( - self.args.manifest_dir - / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz" - ) - else: - path = ( - self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz" - ) - - logging.info(f"Loading manifest from {path}.") - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" - ) - return cuts_valid - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz" - ) - return cuts_valid - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz" - ) - return cuts_valid - - @lru_cache() - def 
librispeech_test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def librispeech_test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" - ) - - @lru_cache() - def long_audio_cuts(self) -> CutSet: - logging.info("About to get long audio cuts") - cuts = load_manifest_lazy( - self.args.long_audio_cuts, - ) - return cuts - - @lru_cache() - def test_dev_cuts(self) -> CutSet: - logging.info("About to get test dev cuts") - cuts = load_manifest_lazy( - self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz" - ) - return cuts diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py b/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py deleted file mode 100644 index e0bf8f73d..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py +++ /dev/null @@ -1,586 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from lhotse import validate -from lhotse.cut import CutSet -from lhotse.dataset import K2SpeechRecognitionDataset -from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures -from lhotse.utils import compute_num_frames, ifnone -from text_normalization import ( - lower_all_char, - lower_only_alpha, - remove_non_alphabetic, - train_text_normalization, - upper_all_char, - upper_only_alpha, -) -from torch.utils.data.dataloader import DataLoader, default_collate - - -class PromptASRDataset(torch.utils.data.Dataset): - """This is a dataset for Prompt ASR. It supports the following features: - 1. Select a tuple of (text, pre_text, style_text) randomly from a - list of texts as supervisions. - - """ - - def __init__( - self, - return_cuts: bool = False, - cut_transforms: List[Callable[[CutSet], CutSet]] = None, - input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, - input_strategy: BatchIO = PrecomputedFeatures(), - text_sampling_func: Optional[Callable[[List[str]], str]] = None, - rare_word_list: Optional[List[str]] = None, - ): - """ - Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py - for more details. 
- - :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut - objects used to create that batch. - :param cut_transforms: A list of transforms to be applied on each sampled batch, - before converting cuts to an input representation (audio/features). - Examples: cut concatenation, noise cuts mixing, etc. - :param input_transforms: A list of transforms to be applied on each sampled batch, - after the cuts are converted to audio/features. - Examples: normalization, SpecAugment, etc. - :param input_strategy: Converts cuts into a collated batch of audio/features. - By default, reads pre-computed features from disk. - :param text_sampling_func: A function that samples one text from a list of texts to use as the transcription. - """ - super().__init__() - # Initialize the fields - self.return_cuts = return_cuts - self.cut_transforms = ifnone(cut_transforms, []) - self.input_transforms = ifnone(input_transforms, []) - self.input_strategy = input_strategy - - # a text sampling function - self.text_sampling_func = text_sampling_func - self.rare_word_list = rare_word_list - - def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: - """ - Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. - """ - validate_for_asr(cuts) - - # Sort the cuts by duration so that the first one determines the batch time dimensions. - cuts = cuts.sort_by_duration(ascending=False) - - # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts - # the supervision boundaries. - for tnfm in self.cut_transforms: - cuts = tnfm(cuts) - - # Sort the cuts again after transforms - cuts = cuts.sort_by_duration(ascending=False) - - # Get a tensor with batched feature matrices, shape (B, T, F) - # Collation performs auto-padding, if necessary. - input_tpl = self.input_strategy(cuts) - if len(input_tpl) == 3: - # An input strategy with fault-tolerant audio reading mode. - # "cuts" may be a subset of the original "cuts" variable, - # that only has cuts for which we successfully read the audio. - inputs, _, cuts = input_tpl - else: - inputs, _ = input_tpl - - # Get a dict of tensors that encode the positional information about supervisions - # in the batch of feature matrices. The tensors are named "sequence_idx", - # "start_frame/sample" and "num_frames/samples". - supervision_intervals = self.input_strategy.supervision_intervals(cuts) - - # Apply all available transforms on the inputs, i.e. either audio or features. - # This could be feature extraction, global MVN, SpecAugment, etc.
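# A framework-free sketch of the step that follows: the per-supervision
# position info returned by `supervision_intervals` (1-D tensors keyed by
# "sequence_idx", "start_frame" and "num_frames") is stacked column-wise into
# a (num_supervisions, 3) tensor and handed to the input transforms as
# `supervision_segments`.  The values below are hypothetical; torch is
# already imported in this file and is repeated here only so the sketch
# stands alone.
import torch

example_intervals = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([1500, 1200]),
}
example_segments = torch.stack(list(example_intervals.values()), dim=1)
assert example_segments.shape == (2, 3)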
- segments = torch.stack(list(supervision_intervals.values()), dim=1) - for tnfm in self.input_transforms: - inputs = tnfm(inputs, supervision_segments=segments) - - batch = { - "inputs": inputs, - "supervisions": default_collate( - [ - self.text_sampling_func( - texts=supervision.texts, - pre_texts=supervision.pre_texts, - context_list=supervision.context_list - if "context_list" in supervision.custom - else None, - rare_word_list=self.rare_word_list, - ) - if self.text_sampling_func is not None - else { - "text": train_text_normalization(supervision.texts[0]), - "pre_text": train_text_normalization(supervision.pre_texts[0]), - "style_text": train_text_normalization( - supervision.pre_texts[0] - ), - "transform_ids": 0, - } - for sequence_idx, cut in enumerate(cuts) - for supervision in cut.supervisions - ] - ), - } - # Update the 'supervisions' field with sequence_idx and start/num frames/samples - batch["supervisions"].update(supervision_intervals) - if self.return_cuts: - batch["supervisions"]["cut"] = [ - cut for cut in cuts for sup in cut.supervisions - ] - - has_word_alignments = all( - s.alignment is not None and "word" in s.alignment - for c in cuts - for s in c.supervisions - ) - - return batch - - -def validate_for_asr(cuts: CutSet) -> None: - validate(cuts) - tol = 2e-3 # 1ms - for cut in cuts: - for supervision in cut.supervisions: - assert supervision.start >= -tol, ( - f"Supervisions starting before the cut are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - # - # 'supervision.end' is end of supervision inside the Cut - assert supervision.end <= cut.duration + tol, ( - f"Supervisions ending after the cut " - f"are not supported for ASR" - f" (sup id: {supervision.id}, cut id: {cut.id})" - ) - - -def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str: - """A helper function that generates a random substring from a given string - - Args: - s (str): Input string - - Returns: - str: Returned substring - """ - min_len = min(len(s), min_len) - - start = random.randint(0, len(s) - min_len) - end = min(start + max_len, random.randint(start + min_len, len(s))) - - return s[start:end] - - -def triplet_text_sampling( - texts: List[str], - pre_texts: List[str], - context_list: Optional[str] = None, - rare_word_list: Optional[List[str]] = None, - transforms: Optional[List[Callable[[str], str]]] = None, - min_len_style: Optional[int] = 80, -) -> Dict[str, str]: - """This function generates a triplet of - (pre_text, style_text, ref_text). The style of style_text and ref_text - should **always** match, whereas the style of pre_text is arbitrary. - Suppose we have 2 different transforms A,B, and the preceding text is - referred to as pre_text. The following three tuples are all valid: - - (A(pre_text), A(style_text), A(ref_text)) - (A(pre_text), B(style_text), B(ref_text)) - (A(pre_text), A(style_text), A(ref_text)) - (B(pre_text), B(style_text), B(ref_text)) - - If transforms is not given, the following pre-defined transforms - are available: - 0: original (mixed-cased, with punc) - 1: upper_only_alpha (upper-cased, no punc) - - When the transform of text and pre_text match, we can use the whole - pre_text as the prompt text. - - Args: - texts (List[str]): - A list of ref_texts whose first item is the ground truth - text from books. - pre_texts (List[str]): - A list of pre_texts, whose first item is the groundtruth - pre_text from books. 
- context_list: Optional[str] = None, - A list of biasing words separated by space - rare_word_list: Optional[str] = None, - A list of rare-words separated by space (used as distractors) - transforms (List[Callable[[str], str]]): A list of possible transforms to be applied - - Returns: - A dictionary of ref_text, pre_text, style_text - """ - assert len(texts) == len(pre_texts) - assert len(texts) == 2 - - # we assume the first item to be ground truth - gt_text = texts[0] - gt_pre_text = pre_texts[0] - - if transforms is None: - transforms = [ - lambda x: x, # return it self - upper_only_alpha, - lower_only_alpha, - lower_all_char, - ] - - sampling_weight = [ - 0.7, - 0.3, - 0.0, - 0.0, - ] # Mixed-punc should have the largest sampling prob - - total_transforms = len(transforms) # do not use the recognized trans - - # Randomly sample transforms - i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) - - # get the normalized text and pre_text - text = transforms[i_text](gt_text) - pre_text = transforms[i_pre_text](gt_pre_text) - - if i_text == i_pre_text: - style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) - else: - # get the pre_text of same style as text - # For now, **don't** do transform to the style text, because we do it after the dataloader - style_text = gt_pre_text - # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) - style_text = get_substring(style_text, min_len=min_len_style, max_len=150) - - return { - "text": train_text_normalization(text), - "pre_text": train_text_normalization(pre_text), - "style_text": train_text_normalization(style_text), - "transform_ids": i_text, - } - - -def triplet_text_sampling_with_context_list( - texts: List[str], - pre_texts: List[str], - context_list: str, - rare_word_list: List[str], - transforms: Optional[List[Callable[[str], str]]] = None, - min_len_style: Optional[int] = 80, -) -> Dict[str, str]: - """This function generates a triplet of - (pre_text, style_text, ref_text). The pre_text is either the preceding text - or a list of words (context words + distractors). - The style of style_text and ref_text should **always** match, whereas - the style of pre_text is arbitrary. - Suppose we have 2 different transforms A,B, and the preceding text is - referred to as pre_text. The following three tuples are all valid: - - (A(pre_text), A(style_text), A(ref_text)) - (A(pre_text), B(style_text), B(ref_text)) - (A(pre_text), A(style_text), A(ref_text)) - (B(pre_text), B(style_text), B(ref_text)) - - If transforms is not given, the following pre-defined transforms - are available: - 0: original (mixed-cased, with punc) - 1: upper_only_alpha (upper-cased, no punc) - - When the transform of text and pre_text match, we can use the whole - pre_text as the prompt text. - - Args: - texts (List[str]): - A list of ref_texts whose first item is the ground truth - text from books. - pre_texts (List[str]): - A list of pre_texts, whose first item is the groundtruth - pre_text from books. 
- context_list: Optional[str] = None, - A list of biasing words separated by space - rare_word_list: Optional[str] = None, - A list of rare-words separated by space (used as distractors) - transforms (List[Callable[[str], str]]): A list of possible transforms to be applied - - Returns: - A dictionary of ref_text, pre_text, style_text - Returns: - str: A dictionary - """ - # import pdb; pdb.set_trace() - assert len(texts) == len(pre_texts) - assert len(texts) == 2 - - if context_list is not None: - context_list = context_list.lower() - - # we assume the first item to be ground truth - gt_text = texts[0] - gt_pre_text = pre_texts[0] - - if transforms is None: - transforms = [ - lambda x: x, # return it self - upper_only_alpha, - lower_only_alpha, - lower_all_char, - ] - - sampling_weight = [ - 0.7, - 0.3, - 0.0, - 0.0, - ] # Mixed-punc should have the largest sampling prob - - total_transforms = len(transforms) # do not use the recognized trans - - # Select a transformation randomly - i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) - - # get the normalized text and pre_text - text = transforms[i_text](gt_text) - pre_text = get_pre_text_with_context_list2( - text=gt_text, - pre_text=gt_pre_text, - context_list=context_list, - rare_words_list=rare_word_list, - ) - pre_text = transforms[i_pre_text](pre_text) - - if i_text == i_pre_text: - style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) - else: - # get the pre_text of same style as text - # For now, **don't** do transform to the style text - style_text = gt_pre_text - # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) - style_text = get_substring(style_text, min_len=min_len_style, max_len=150) - - return { - "text": train_text_normalization(text), - "pre_text": train_text_normalization(pre_text), - "style_text": train_text_normalization(style_text), - "transform_ids": i_text, - } - - -def get_pre_text_with_context_list( - text: str, - pre_text: str, - context_list: str, - rare_words_list: List[str] = None, -) -> str: - # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha - # By a small proportion of time, use the substring of ref_text as pre_text - - if context_list != "" and context_list is not None: - v = random.random() - if v < 0.5: - # correct + distractors - # sample distractors - num_distractors = random.randint(0, 50) - distractors = random.sample(rare_words_list, num_distractors) - # sample correct - correct = context_list.split() - i = random.randint(1, len(correct)) - correct = random.sample(correct, i) - # combine correct and distractors - pre_text = distractors + correct - random.shuffle(pre_text) - pre_text = " ".join(pre_text) - elif v < 0.7: - splitted = text.split() - sampling_weights = [len(w) ** 1.2 for w in splitted] - sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] - i = random.randint(1, min(len(splitted), 20)) - splitted = list(np.random.choice(splitted, i, p=sampling_weights)) - num_distractors = random.randint(0, 70) - distractors = random.sample(rare_words_list, num_distractors) - splitted += distractors - random.shuffle(splitted) # shuffle the list - pre_text = " ".join(splitted) - else: - pre_text = pre_text - else: - v = random.random() - if v < 0.1: - splitted = text.split() - sampling_weights = [len(w) ** 1.2 for w in splitted] - sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] - i = random.randint(1, min(len(splitted), 20)) - splitted = 
list(np.random.choice(splitted, i, p=sampling_weights)) - pre_text = " ".join(splitted) - num_distractors = random.randint(0, 70) - distractors = random.sample(rare_words_list, num_distractors) - splitted += distractors - random.shuffle(splitted) # shuffle the list - elif v < 0.2: - # full distractors - num_distractors = random.randint(5, 100) - distractors = random.sample(rare_words_list, num_distractors) - pre_text = " ".join(distractors) - - elif v < 0.3: - pre_text = get_substring(text, min_len=15, max_len=150) - else: - pre_text = pre_text - - return pre_text - - -def get_pre_text_with_context_list2( - text: str, - pre_text: str, - context_list: str, - rare_words_list: List[str] = None, -) -> str: - # Get the pre_text, either the ground truth preceding text or - # a list of words consisting of biasing words and distrators - # By a small proportion of time, use the substring of ref_text as pre_text - - if context_list != "" and context_list is not None: - v = random.random() - if v < 0.4: - # sample distractors - num_distractors = random.randint(50, 100) - distractors = random.sample(rare_words_list, num_distractors) - # sample correct - correct = context_list.split() - i = random.randint(1, len(correct)) - correct = random.sample(correct, i) - # combine correct and distractors - pre_text = distractors + correct - random.shuffle(pre_text) - pre_text = " ".join(pre_text) - elif v < 0.55: - splitted = text.split() - sampling_weights = [ - len(w) ** 1.2 for w in splitted - ] # longer words with higher weights - sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] - i = random.randint(1, min(len(splitted), 20)) - splitted = list(np.random.choice(splitted, i, p=sampling_weights)) - num_distractors = random.randint(50, 100) - distractors = random.sample(rare_words_list, num_distractors) - splitted += distractors - random.shuffle(splitted) # shuffle the list - pre_text = " ".join(splitted) - else: - pre_text = pre_text - else: - v = random.random() - if v < 0.3: - splitted = text.split() - sampling_weights = [len(w) ** 1.2 for w in splitted] - sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] - i = random.randint(1, min(len(splitted), 20)) - splitted = list(np.random.choice(splitted, i, p=sampling_weights)) - pre_text = " ".join(splitted) - num_distractors = random.randint(50, 100) - distractors = random.sample(rare_words_list, num_distractors) - splitted += distractors - random.shuffle(splitted) # shuffle the list - elif v < 0.4: - # full distractors - num_distractors = random.randint(5, 100) - distractors = random.sample(rare_words_list, num_distractors) - pre_text = " ".join(distractors) - elif v < 0.6: - pre_text = get_substring(text, min_len=15, max_len=150) - else: - pre_text = pre_text - - return pre_text - - -def naive_triplet_text_sampling( - texts: List[str], - pre_texts: List[str], - context_list: str = None, - rare_word_list: List[str] = None, - min_len_style: Optional[int] = 120, -): - # The most simplest text sampling function, used only for - # evaluation, use a fixed sentence as the style text - - return { - "text": train_text_normalization(texts[0]), - "pre_text": train_text_normalization(pre_texts[0]), - "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. 
What do you think?", - "transform_ids": 0, - } - - -def random_shuffle_subset( - data: List[str], - p: float = 0.2, - p_mask: float = 0.05, -) -> List[str]: - """ - Randomly shuffle the subset by probability `p`, which means that p% of the samples - in the original batch are shuffled, the others are kept in the original order. - - With a probability of `p_mask`, replace the original string with an empty string. - - """ - - num_to_shuffle = int(len(data) * p) - id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False) - item_to_shuffle = [data[id] for id in id_to_shuffle] - random.shuffle(item_to_shuffle) - - for id, item in zip(id_to_shuffle, item_to_shuffle): - data[id] = item - - # Randomly mask a proportion of the data to empty string - if p_mask > 0: - for i in range(len(data)): - if random.random() < p_mask: - data[i] = "" - - return data - - -if __name__ == "__main__": - texts = [ - "AA, BB, cC, dD!", - "AA BB CC DD", - ] - - pre_texts = [ - "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?", - "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG", - ] - for i in range(10): - print(f"Run: {i}") - print(triplet_text_sampling(texts, pre_texts)) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py deleted file mode 100644 index 6a3bab3c8..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py +++ /dev/null @@ -1,791 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -""" - - -import argparse -import logging -import math -import warnings -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from beam_search import greedy_search, greedy_search_batch, modified_beam_search -from ls_text_normalization import word_normalization -from text_normalization import ( - ref_text_normalization, - remove_non_alphabetic, - upper_only_alpha, -) -from train_baseline import add_model_arguments, get_params, get_transducer_model -from utils import write_error_stats - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion - - modified_beam_search_LODR - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. 
- """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=True, - help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", - ) - - parser.add_argument( - "--long-audio-recog", - type=str2bool, - default=False, - ) - - parser.add_argument( - "--use-ls-test-set", - type=str2bool, - default=False, - help="Use librispeech test set for evaluation.", - ) - - parser.add_argument( - "--compute-CER", - type=str2bool, - default=True, - help="Reports CER. By default, only reports WER", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` - set to true. - ngram_lm: - A ngram lm. Used in LODR decoding. - ngram_lm_scale: - The scale of the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - texts = batch["supervisions"]["text"] - batch_size = feature.size(0) - - # Get the transducer encoder output - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=feature, - feature_lens=feature_lens, - ) - - hyps = [] - - if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
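decode_one_batch returns hypotheses keyed by the decoding setting ("greedy_search" or "beam_size_<k>"); a small self-contained sketch of how decode_dataset folds that dict into (cut_id, ref_words, hyp_words) triplets, using made-up hypotheses instead of a real model:

from typing import Dict, List

# Hypothetical output of decode_one_batch for a batch of two utterances.
hyps_dict: Dict[str, List[List[str]]] = {
    "greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
}
texts = ["hello world", "good morning"]
cut_ids = ["cut-0001", "cut-0002"]

results: Dict[str, list] = {}
for name, hyps in hyps_dict.items():
    assert len(hyps) == len(texts)
    this_batch = []
    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
        # In the recipe the reference goes through ref_text_normalization;
        # a plain upper-case split stands in for it here.
        this_batch.append((cut_id, ref_text.upper().split(), hyp_words))
    results.setdefault(name, []).extend(this_batch)

print(results["greedy_search"][0])
# ('cut-0001', ['HELLO', 'WORLD'], ['HELLO', 'WORLD'])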
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - if not params.use_ls_test_set: - book_names = [ - cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] - ] - else: - book_names = ["" for _ in cut_ids] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, book_name, hyp_words, ref_text in zip( - cut_ids, book_names, hyps, texts - ): - ref_text = ref_text_normalization(ref_text) - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - # if not params.use_ls_test_set: - # results[name + " " + book_name].extend(this_batch) - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], - biasing_words: List[str] = None, -): - test_set_wers = dict() - test_set_cers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
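save_results writes one summary line per decoding setting, sorted so the lowest WER comes first; a tiny sketch of that sorting with made-up numbers:

test_set_wers = {"greedy_search": 4.12, "beam_size_4": 3.87}  # hypothetical WERs
# Sorting by value lists the best (lowest) WER first; the "\tbest for <set>"
# note is attached only to that first line in the real summary file.
for key, val in sorted(test_set_wers.items(), key=lambda x: x[1]):
    print(f"{key}\t{val}")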
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - biasing_words=biasing_words, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - if params.compute_CER: - # Write CER statistics - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=results, char_level=True) - errs_filename = ( - params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - cer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - compute_CER=params.compute_CER, - ) - test_set_cers[key] = cer - - logging.info("Wrote detailed CER stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - if params.compute_CER: - test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tcER", file=f) - for key, val in test_set_cers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_cers: - s += "{} CER\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "modified_beam_search", - ) - - if params.long_audio_recog: - params.res_dir = params.exp_dir / (params.decoding_method + "long_audio") - else: - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if "ngram" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - 
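params.blank_id and params.unk_id are looked up from the BPE model by piece name; a minimal sketch, assuming the special pieces are spelled "<blk>" and "<unk>" as defined in local/train_bpe_model.py (the path below matches the --bpe-model default):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # path from --bpe-model

# <blk> and <unk> are assumed to be the special pieces added when the
# BPE model was trained; the blank is conventionally piece id 0.
blank_id = sp.piece_to_id("<blk>")
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
print(blank_id, unk_id, vocab_size)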
logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - LM = None - - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
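The two checkpoint-loading branches select different epoch files; a small sketch of which checkpoints are involved for --epoch 28 --avg 15 (the exp dir path is illustrative):

from pathlib import Path

epoch, avg = 28, 15
exp_dir = Path("./zipformer_prompt_asr/exp")

# Without --use-averaged-model: plain parameter averaging over the last
# `avg` epoch checkpoints, i.e. epoch-14.pt .. epoch-28.pt.
plain_avg_files = [exp_dir / f"epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]

# With --use-averaged-model: only the two endpoint files are loaded
# (epoch-13.pt and epoch-28.pt); the result covers the epoch range
# (epoch - avg, epoch], so epoch 13 itself is excluded from the average.
start = epoch - avg
endpoints = (exp_dir / f"epoch-{start}.pt", exp_dir / f"epoch-{epoch}.pt")
print(len(plain_avg_files), endpoints)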
- args.return_cuts = True - libriheavy = LibriHeavyAsrDataModule(args) - - test_clean_cuts = libriheavy.test_clean_cuts() - test_other_cuts = libriheavy.test_other_cuts() - ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() - ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() - long_audio_cuts = libriheavy.long_audio_cuts() - - test_clean_dl = libriheavy.valid_dataloaders( - test_clean_cuts, - ) - test_other_dl = libriheavy.valid_dataloaders( - test_other_cuts, - ) - ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) - ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) - long_audio_dl = libriheavy.valid_dataloaders( - long_audio_cuts, - ) - - if params.use_ls_test_set: - test_sets = ["ls-test-clean", "ls-test-other"] - test_dl = [ls_test_clean_dl, ls_test_other_dl] - else: - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - if params.long_audio_recog: - test_sets = ["long-audio"] - test_dl = [long_audio_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - if params.use_ls_test_set: - f = open( - "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r" - ) - biasing_words = f.read().strip().split() - f.close() - else: - biasing_words = None - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if params.post_normalization: - if "-post-normalization" not in params.suffix: - params.suffix += "-post-normalization" - - new_res = {} - for k in results_dict: - new_ans = [] - for item in results_dict[k]: - id, ref, hyp = item - if params.use_ls_test_set: - hyp = ( - " ".join(hyp).replace("-", " ").split() - ) # handle the hypens - hyp = upper_only_alpha(" ".join(hyp)).split() - hyp = [word_normalization(w.upper()) for w in hyp] - hyp = " ".join(hyp).split() - hyp = [w for w in hyp if w != ""] - ref = upper_only_alpha(" ".join(ref)).split() - else: - hyp = upper_only_alpha(" ".join(hyp)).split() - ref = upper_only_alpha(" ".join(ref)).split() - new_ans.append((id, ref, hyp)) - new_res[k] = new_ans - - save_results( - params=params, - test_set_name=test_set, - results_dict=new_res, - biasing_words=biasing_words, - ) - - if params.suffix.endswith("-post-normalization"): - params.suffix = params.suffix.replace("-post-normalization", "") - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py deleted file mode 100755 index e71999b0a..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py +++ /dev/null @@ -1,1025 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer_prompt_asr/decode_bert.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --decoding-method greedy_search \ - --text-encoder-type BERT \ - --memory-layer 0 \ - --use-pre-text True \ - --use-style-prompt True \ - --max-prompt-lens 1000 \ - --style-text-transform mixed-punc \ - --pre-text-transform mixed-punc \ - --compute-CER 0 - - -(2) modified beam search -./zipformer_prompt_asr/decode_bert.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --text-encoder-type BERT \ - --memory-layer 0 \ - --use-pre-text True \ - --use-style-prompt True \ - --max-prompt-lens 1000 \ - --style-text-transform mixed-punc \ - --pre-text-transform mixed-punc \ - --compute-CER 0 - -(3) Decode LibriSpeech - -./zipformer_prompt_asr/decode_bert.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --use-ls-test-set True \ - --beam-size 4 \ - --text-encoder-type BERT \ - --memory-layer 0 \ - --use-pre-text True \ - --use-style-prompt True \ - --max-prompt-lens 1000 \ - --style-text-transform mixed-punc \ - --pre-text-transform mixed-punc \ - --compute-CER 0 - -(4) Decode LibriSpeech + biasing list - -biasing_list=100 # could also be 0 - -./zipformer_prompt_asr/decode_bert.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --use-ls-test-set True \ - --use-ls-context-list True \ - --biasing-level utterance \ - --ls-distractors $biasing_list \ - --post-normalization True \ - --text-encoder-type BERT \ - --max-prompt-lens 1000 \ - --style-text-transform mixed-punc \ - --pre-text-transform mixed-punc - - -""" - - -import argparse -import logging -import math -import warnings -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from beam_search import greedy_search, greedy_search_batch, modified_beam_search -from dataset import naive_triplet_text_sampling, random_shuffle_subset -from ls_text_normalization import word_normalization -from text_normalization import ( - _apply_style_transform, - lower_all_char, - lower_only_alpha, - ref_text_normalization, - remove_non_alphabetic, - train_text_normalization, - upper_all_char, - upper_only_alpha, -) -from train_bert_encoder import ( - _encode_texts_as_bytes_with_tokenizer, - add_model_arguments, - get_params, - get_tokenizer, - get_transducer_model, -) -from transformers import BertModel, BertTokenizer -from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the 
checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion - - modified_beam_search_LODR - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-pre-text", - type=str2bool, - default=True, - help="Use pre-text is available during decoding", - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Use style prompt when evaluation", - ) - - parser.add_argument( - "--max-prompt-lens", - type=int, - default=1000, - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=True, - help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", - ) - - parser.add_argument( - "--compute-CER", - type=str2bool, - default=False, - help="Reports CER. By default, only reports WER", - ) - - parser.add_argument( - "--style-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of style prompt, i.e style_text", - ) - - parser.add_argument( - "--pre-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of content prompt, i.e pre_text", - ) - - parser.add_argument( - "--use-ls-test-set", - type=str2bool, - default=False, - help="Use librispeech test set for evaluation.", - ) - - parser.add_argument( - "--use-ls-context-list", - type=str2bool, - default=False, - help="If use a fixed context list for LibriSpeech decoding", - ) - - parser.add_argument( - "--biasing-level", - type=str, - default="utterance", - choices=["utterance", "Book", "Chapter"], - ) - - parser.add_argument( - "--ls-distractors", - type=int, - default=0, - help="The number of distractors into context list for LibriSpeech decoding", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - tokenizer: spm.SentencePieceProcessor, - batch: dict, - biasing_dict: dict = None, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - tokenizer: - Tokenizer for the text encoder - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - biasing_dict: - A dictionary in the form `{cut_id: :w1 w2"}` that contains a list - of biasing words (separated with space) - word_table: - The word symbol table. - decoding_graph: - The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` - set to true. - ngram_lm: - A ngram lm. Used in LODR decoding. - ngram_lm_scale: - The scale of the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - cuts = batch["supervisions"]["cut"] - cut_ids = [c.supervisions[0].id for c in cuts] - batch_size = feature.size(0) - - if "pre_text" in batch["supervisions"] and params.use_pre_text: - pre_texts = batch["supervisions"]["pre_text"] - pre_texts = [train_text_normalization(t) for t in pre_texts] - else: - pre_texts = ["" for _ in range(batch_size)] - - # get the librispeech biasing data - if params.use_pre_text and (params.use_ls_context_list and params.use_ls_test_set): - if params.biasing_level == "utterance": - pre_texts = [biasing_dict[id] for id in cut_ids] - elif params.biasing_level == "Chapter": - chapter_ids = [c.split("-")[1] for c in cut_ids] - pre_texts = [biasing_dict[id] for id in chapter_ids] - elif params.biasing_level == "Book": - chapter_ids = [c.split("-")[1] for c in cut_ids] - pre_texts = [biasing_dict[id] for id in chapter_ids] - else: - raise ValueError(f"Unseen biasing level: {params.biasing_level}") - if params.pre_text_transform == "mixed-punc": - pre_texts = [t.lower() for t in pre_texts] - - # get style_text - if params.use_style_prompt: - fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." - style_texts = batch["supervisions"].get( - "style_text", [fixed_sentence for _ in range(batch_size)] - ) - style_texts = [train_text_normalization(t) for t in style_texts] - else: - style_texts = ["" for _ in range(batch_size)] # use empty string - - # Get the text embedding - if params.use_pre_text or params.use_style_prompt: - # apply style transform to the pre_text and style_text - pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) - if not params.use_ls_context_list: - pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts] - - if params.use_style_prompt: - style_texts = _apply_style_transform( - style_texts, params.style_text_transform - ) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Use tokenizer to prepare input for text encoder - encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( - pre_texts=pre_texts, - style_texts=style_texts, - tokenizer=tokenizer, - device=device, - no_limit=True, - ) - logging.info( - f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}" - ) - - memory, memory_key_padding_mask = model.encode_text( - encoded_inputs=encoded_inputs, - style_lens=style_lens, - ) # (T,B,C) - else: - memory = None - memory_key_padding_mask = None - - # Get the transducer encoder output - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=feature, - feature_lens=feature_lens, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - hyps = [] - - if params.decoding_method == 
"greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - tokenizer: spm.SentencePieceProcessor, - biasing_dict: Dict = None, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - tokenizer: - Tokenizer for the text encoder - biasing_dict: - A dictionary in the form `{cut_id: :w1 w2"}` that contains a list - of biasing words (separated with space) - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"][ - "text" - ] # By default, this should be in mixed-punc format - - # the style of ref_text should match style_text - texts = _apply_style_transform(texts, params.style_text_transform) - if params.use_style_prompt: - texts = _apply_style_transform(texts, params.style_text_transform) - - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - if not params.use_ls_test_set: - try: - book_names = [ - cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] - ] - except AttributeError: - book_names = [ - cut.id.split("/")[0] for cut in batch["supervisions"]["cut"] - ] - else: - book_names = ["" for _ in cut_ids] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - biasing_dict=biasing_dict, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, book_name, hyp_words, ref_text in zip( - cut_ids, book_names, hyps, texts - ): - ref_text = ref_text_normalization( - ref_text - ) # remove full-width symbols & some book marks - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], - biasing_words: List[str] = None, -): - test_set_wers = dict() - test_set_cers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
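Both the prompts and, when a style prompt is used, the reference texts go through _apply_style_transform; a self-contained sketch of the four supported styles, using simplified stand-ins for the text_normalization helpers:

import re
from typing import List

# Simplified stand-ins for upper_only_alpha / lower_only_alpha / lower_all_char
# from text_normalization.py; the real helpers handle more edge cases.
def upper_only_alpha(s: str) -> str:
    return re.sub(r"[^A-Z'\s]+", "", s.upper()).strip()

def lower_only_alpha(s: str) -> str:
    return re.sub(r"[^a-z'\s]+", "", s.lower()).strip()

def lower_all_char(s: str) -> str:
    return s.lower()

def _apply_style_transform(text: List[str], transform: str) -> List[str]:
    if transform == "mixed-punc":
        return text
    elif transform == "upper-no-punc":
        return [upper_only_alpha(s) for s in text]
    elif transform == "lower-no-punc":
        return [lower_only_alpha(s) for s in text]
    elif transform == "lower-punc":
        return [lower_all_char(s) for s in text]
    raise NotImplementedError(f"Unseen transform: {transform}")

ref = ["Mixed-case English transcription, with punctuation."]
for style in ["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"]:
    print(style, _apply_style_transform(ref, style))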
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - if params.compute_CER: - # Write CER statistics - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=results, char_level=True) - errs_filename = ( - params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - cer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - compute_CER=params.compute_CER, - ) - test_set_cers[key] = cer - - logging.info("Wrote detailed CER stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - if params.compute_CER: - test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_cers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_cers: - s += "{} CER\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "modified_beam_search", - ) - - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_pre_text: - params.suffix += ( - f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}" - ) - - if params.use_style_prompt: - params.suffix += f"-style-prompt-{params.style_text_transform}" - - if params.use_ls_context_list: - assert ( - params.use_pre_text - ), "Must set --use-pre-text to True if using context list" - params.suffix += f"-use-{params.biasing_level}-level-ls-context-list" - if params.biasing_level == "utterance" and params.ls_distractors: - params.suffix += f"-ls-context-distractors-{params.ls_distractors}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - tokenizer = get_tokenizer(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to(device) - model.eval() - - LM = None - - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - libriheavy = LibriHeavyAsrDataModule(args) - - test_clean_cuts = libriheavy.test_clean_cuts() - test_other_cuts = libriheavy.test_other_cuts() - ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() - ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() - - test_clean_dl = libriheavy.valid_dataloaders( - test_clean_cuts, text_sampling_func=naive_triplet_text_sampling - ) - test_other_dl = libriheavy.valid_dataloaders( - test_other_cuts, text_sampling_func=naive_triplet_text_sampling - ) - ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) - ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) - - if params.use_ls_test_set: - test_sets = ["ls-test-clean", "ls-test-other"] - test_dl = [ls_test_clean_dl, ls_test_other_dl] - else: - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - biasing_dict = None - if params.use_ls_context_list: - if test_set == "ls-test-clean": - biasing_dict = get_facebook_biasing_list( - test_set="test-clean", - num_distractors=params.ls_distractors, - ) - elif test_set == "ls-test-other": - biasing_dict = get_facebook_biasing_list( - test_set="test-other", - num_distractors=params.ls_distractors, - ) - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - biasing_dict=biasing_dict, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if params.post_normalization: - if "-post-normalization" not in params.suffix: - params.suffix += "-post-normalization" - - new_res = {} - for k in results_dict: - new_ans = [] - for item in results_dict[k]: - id, ref, hyp = item - if params.use_ls_test_set: - hyp = ( - " ".join(hyp).replace("-", " ").split() - ) # handle the hypens - hyp = upper_only_alpha(" ".join(hyp)).split() - hyp = [word_normalization(w.upper()) for w in hyp] - hyp = " ".join(hyp).split() - hyp = [w for w in hyp if w != ""] - ref = upper_only_alpha(" ".join(ref)).split() - else: - hyp = upper_only_alpha(" ".join(hyp)).split() - ref = upper_only_alpha(" ".join(ref)).split() - new_ans.append((id, ref, hyp)) - new_res[k] = new_ans - - save_results( - params=params, - test_set_name=test_set, - results_dict=new_res, - ) - - if params.suffix.endswith("-post-normalization"): - params.suffix = params.suffix.replace("-post-normalization", "") - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py deleted file mode 100755 index 4559ebb6d..000000000 --- 
a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py +++ /dev/null @@ -1,963 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -""" - - -import argparse -import logging -import math -import warnings -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from beam_search import ( - greedy_search, - greedy_search_batch, - greedy_search_batch_with_context, - greedy_search_with_context, - modified_beam_search, -) -from dataset import naive_triplet_text_sampling, random_shuffle_subset -from lhotse import load_manifest_lazy -from text_normalization import ( - lower_all_char, - lower_only_alpha, - ref_text_normalization, - remove_non_alphabetic, - train_text_normalization, - upper_all_char, - upper_only_alpha, -) -from train_bert_encoder_with_style import ( - _encode_texts_as_bytes_with_tokenizer, - add_model_arguments, - get_params, - get_tokenizer, - get_transducer_model, -) -from transformers import BertModel, BertTokenizer -from utils import get_facebook_biasing_list - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--log-dir", - type=str, - required=True, - help="Where to store the logs", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion - - modified_beam_search_LODR - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--input-manifest", - type=str, - required=True, - help="The input manifest to be decoded", - ) - - parser.add_argument( - "--output-manifest", - type=str, - required=True, - help="Where to store the output manifest (directory)", - ) - - parser.add_argument( - "--use-pre-text", - type=str2bool, - default=True, - help="Use pre-text is available during decoding", - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Use style prompt when evaluation", - ) - - parser.add_argument( - "--use-context-embedding", - type=str2bool, - default=False, - help="Use context fuser when evaluation", - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=True, - help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", - ) - - parser.add_argument( - "--compute-CER", - type=str2bool, - default=True, - help="Reports CER. By default, only reports WER", - ) - - parser.add_argument( - "--style-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of style prompt, i.e style_text", - ) - - parser.add_argument( - "--pre-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of content prompt, i.e pre_text", - ) - - parser.add_argument( - "--use-ls-test-set", - type=str2bool, - default=False, - help="Use librispeech test set for evaluation.", - ) - - parser.add_argument( - "--use-ls-context-list", - type=str2bool, - default=False, - help="If use a fixed context list for LibriSpeech decoding", - ) - - add_model_arguments(parser) - - return parser - - -def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in - ground truth format, i.e mixed-punc. - - Args: - text (List[str]): Input text string - transform (str): Transform to be applied - - Returns: - List[str]: _description_ - """ - if transform == "mixed-punc": - return text - elif transform == "upper-no-punc": - return [upper_only_alpha(s) for s in text] - elif transform == "lower-no-punc": - return [lower_only_alpha(s) for s in text] - elif transform == "lower-punc": - return [lower_all_char(s) for s in text] - else: - raise NotImplementedError(f"Unseen transform: {transform}") - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - tokenizer, - batch: dict, - biasing_dict: dict = None, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` - set to true. - ngram_lm: - A ngram lm. Used in LODR decoding. - ngram_lm_scale: - The scale of the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - cuts = batch["supervisions"]["cut"] - cut_ids = [c.supervisions[0].id for c in cuts] - batch_size = feature.size(0) - - # get pre_text - if "pre_text" in batch["supervisions"] and params.use_pre_text: - pre_texts = batch["supervisions"][ - "text" - ] # use the ground truth ref text as pre_text - pre_texts = [train_text_normalization(t) for t in pre_texts] - else: - pre_texts = ["" for _ in range(batch_size)] - - if params.use_ls_context_list: - pre_texts = [biasing_dict[id] for id in cut_ids] - - # get style_text - if params.use_style_prompt: - fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." - style_texts = batch["supervisions"].get( - "style_text", [fixed_sentence for _ in range(batch_size)] - ) - style_texts = [train_text_normalization(t) for t in style_texts] - else: - style_texts = ["" for _ in range(batch_size)] # use empty string - - # Get the text embedding input - if params.use_pre_text or params.use_style_prompt: - - # apply style transform to the pre_text and style_text - pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) - # pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) - if params.use_style_prompt: - style_texts = _apply_style_transform( - style_texts, params.style_text_transform - ) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Use tokenizer to prepare input for text encoder - encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( - pre_texts=pre_texts, - style_texts=style_texts, - tokenizer=tokenizer, - device=device, - ) - - memory, memory_key_padding_mask = model.encode_text( - encoded_inputs=encoded_inputs, - style_lens=style_lens, - ) # (T,B,C) - else: - memory = None - memory_key_padding_mask = None - - # Get the transducer encoder output - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=feature, - feature_lens=feature_lens, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - hyps = [] - - if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - if memory is None or not params.use_context_embedding: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) - context = model.context_fuser( - memory, padding_mask=memory_key_padding_mask - ) # (N,C) - context = model.joiner.context_proj(context) # (N,C) - hyp_tokens = greedy_search_batch_with_context( - model=model, - encoder_out=encoder_out, - 
encoder_out_lens=encoder_out_lens, - context=context, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - if memory is None or not params.use_context_embedding: - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - else: - cur_context = context[i : i + 1, :] - hyp = greedy_search_with_context( - model=model, - encoder_out=encoder_out_i, - context=cur_context, - max_sym_per_frame=params.max_sym_per_frame, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - tokenizer, - biasing_dict: Dict = None, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
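The `_apply_style_transform` helper above relies on `upper_only_alpha`, `lower_only_alpha` and `lower_all_char`, which are presumably provided by the recipe's text-normalization utilities and are not shown in this diff. Purely as a hedged illustration of what the four styles mean (not the recipe's actual implementation), minimal versions could look like:

```python
import re


def upper_only_alpha(s: str) -> str:
    # Drop everything except letters, digits, spaces and apostrophes,
    # then uppercase (roughly "upper-no-punc").
    return re.sub(r"[^a-zA-Z0-9\s']+", "", s).upper()


def lower_only_alpha(s: str) -> str:
    # Same filtering, but lowercase ("lower-no-punc").
    return re.sub(r"[^a-zA-Z0-9\s']+", "", s).lower()


def lower_all_char(s: str) -> str:
    # Keep punctuation, only lowercase the characters ("lower-punc").
    return s.lower()


print(upper_only_alpha("Hello, world!"))  # -> "HELLO WORLD"
print(lower_all_char("Hello, world!"))    # -> "hello, world!"
```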
- - if params.decoding_method == "greedy_search": - log_interval = 40 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"][ - "text" - ] # By default, this should be in mixed-punc format - - # the style of ref_text should match style_text - texts = _apply_style_transform(texts, params.style_text_transform) - - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - biasing_dict=biasing_dict, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = ref_text_normalization( - ref_text - ) # remove full-width symbols & some bookmarks - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed so far: {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - test_set_cers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - ref/hyp pairs. 
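`write_error_stats` (from `icefall.utils`) computes the WER reported below from the `(cut_id, ref_words, hyp_words)` tuples collected by `decode_dataset` above. As a rough sketch of the metric itself only (not the icefall implementation), WER is the word-level edit distance divided by the reference length:

```python
def word_error_rate(ref: list, hyp: list) -> float:
    # Levenshtein distance over words: substitutions, insertions, deletions.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,         # deletion
                d[i][j - 1] + 1,         # insertion
                d[i - 1][j - 1] + cost,  # substitution
            )
    return d[len(ref)][len(hyp)] / max(len(ref), 1)


print(word_error_rate("the cat sat".split(), "the cat sat down".split()))  # -> 0.333...
```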
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - if params.compute_CER: - # Write CER statistics - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=results, char_level=True) - errs_filename = ( - params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - cer = write_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - compute_CER=params.compute_CER, - ) - test_set_cers[key] = cer - - logging.info("Wrote detailed CER stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - if params.compute_CER: - test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_cers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_cers: - s += "{} CER\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -def add_decoding_result_to_manifest( - in_manifest, - out_manifest: str, - results_dict: Dict, -): - # write the decoding results with prompt to the manifest as an - # extra ref text - new_ans = {} - for key, value in results_dict.items(): - for items in value: - id, ref, hyp = items - new_ans[id] = " ".join(hyp) - - def _add_decoding(c): - key = c.supervisions[0].id - c.supervisions[0].texts.append(new_ans[key]) - return c - - in_manifest = in_manifest.map(_add_decoding) - logging.info(f"Saving manifest to {out_manifest}") - in_manifest.to_file(out_manifest) - - -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - cuts = load_manifest_lazy(args.input_manifest) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - splitted_cuts = cuts.split(num_splits=world_size) - mp.spawn( - run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True - ) - else: - run(rank=0, world_size=1, args=args, cuts=cuts) - - -@torch.no_grad() -def run(rank, world_size, args, cuts): - params = get_params() - params.update(vars(args)) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.use_pre_text: - params.suffix += f"-pre-text-{params.pre_text_transform}" - - if params.use_style_prompt: - params.suffix += f"-style-prompt-{params.style_text_transform}" - - params.suffix += f"-{rank}" - - 
world_size = params.world_size - - params.output_manifest = Path(params.output_manifest) - if world_size > 1: - cuts = cuts[rank] - out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz" - else: - out_name = params.output_manifest / "with_decoding.jsonl.gz" - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}") - logging.info("Decoding started") - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - tokenizer = get_tokenizer(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - LM = None - - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
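When `--use-averaged-model=True` is combined with `--epoch` and `--avg`, the branch above averages over the epoch range `(epoch - avg, epoch]`, and `average_checkpoints_with_averaged_model` only needs the two endpoint checkpoints. A minimal sketch of which files are involved, assuming the `epoch-N.pt` naming used above:

```python
from pathlib import Path


def averaging_endpoints(exp_dir: str, epoch: int, avg: int) -> tuple:
    # Mirrors the logic above: average over (epoch - avg, epoch],
    # which only requires the checkpoints at the two endpoints.
    assert avg > 0 and epoch - avg >= 1
    start = epoch - avg
    filename_start = Path(exp_dir) / f"epoch-{start}.pt"
    filename_end = Path(exp_dir) / f"epoch-{epoch}.pt"
    return filename_start, filename_end


print(averaging_endpoints("pruned_transducer_stateless7/exp", epoch=30, avg=9))
# -> (.../epoch-21.pt, .../epoch-30.pt)
```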
- args.return_cuts = True - libriheavy = LibriHeavyAsrDataModule(args) - - dl = libriheavy.valid_dataloaders( - cuts, text_sampling_func=naive_triplet_text_sampling - ) - - test_sets = ["test"] - test_dl = [dl] - - for test_set, test_dl in zip(test_sets, test_dl): - biasing_dict = None - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - biasing_dict=biasing_dict, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - # save_results( - # params=params, - # test_set_name=test_set, - # results_dict=results_dict, - # ) - - add_decoding_result_to_manifest( - in_manifest=cuts, - out_manifest=out_name, - results_dict=results_dict, - ) - - logging.info("Done!") - - -# torch.set_num_threads(1) -# torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py deleted file mode 100644 index 91f167204..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from scaling import Balancer - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - padding_idx=blank_id, - ) - # the balancers are to avoid any drift in the magnitude of the - # embeddings, which would interact badly with parameter averaging. 
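The `Decoder` being removed here is "stateless": instead of an RNN, the prediction network embeds only the last `context_size` labels and mixes them with a 1-D convolution (its full definition continues below). A stripped-down sketch of the same idea, without the `Balancer` modules used in the real class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyStatelessDecoder(nn.Module):
    # Simplified: embedding + Conv1d over the last `context_size` labels.
    def __init__(self, vocab_size=500, decoder_dim=512, blank_id=0, context_size=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim, padding_idx=blank_id)
        self.context_size = context_size
        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size, bias=False)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label sequence; returns (N, U, decoder_dim).
        emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # left-pad so the output length stays U
        return F.relu(self.conv(emb)).permute(0, 2, 1)


dec = TinyStatelessDecoder()
print(dec(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])
```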
- self.balancer = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) - - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - self.balancer2 = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) - else: - # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` - # when inference with torch.jit.script and context_size == 1 - self.conv = nn.Identity() - self.balancer2 = nn.Identity() - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - - embedding_out = self.balancer(embedding_out) - - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) - - return embedding_out diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py b/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. 
- Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py deleted file mode 100644 index e0bc556a8..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -""" -Export `model.state_dict()` - -- For non-streaming model: - -./zipformer_prompt_asr/export_PromptASR.py \ - --exp-dir ./zipformer_prompt_asr/exp \ - --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ - --epoch 50 \ - --avg 10 - -- For streaming model: - -./zipformer_prompt_asr/export_PromptASR.py \ - --exp-dir ./zipformer_prompt_asr/exp \ - --causal 1 \ - --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ - --epoch 50 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -""" - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from torch import Tensor, nn -from train_bert_encoder import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - 
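Per the header of this export script, the resulting `pretrained.pt` (saved at the end of `main()` below as `{"model": model.state_dict()}`) can later be reloaded with `icefall.checkpoint.load_checkpoint()`. A minimal sketch using plain `torch.load`, assuming the model has been rebuilt with the same configuration (e.g. via `get_transducer_model(params)`):

```python
import torch
from torch import nn


def load_pretrained(model: nn.Module, filename: str) -> nn.Module:
    # `filename` points at the pretrained.pt written by this script,
    # which stores a single dict: {"model": state_dict}.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model
```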
model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - assert params.jit is False, "Jit is not supported yet" - - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py b/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py deleted file mode 100644 index 59f822748..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - context_dim: int = 512, - context_injection: bool = False, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - if context_injection: - self.context_proj = ScaledLinear( - context_dim, joiner_dim, initial_scale=0.25 - ) - else: - self.context_proj = None - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - context: torch.Tensor = None, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - context: - An embedding vector representing the previous context information - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). 
- """ - assert encoder_out.ndim == decoder_out.ndim == 4 - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - - if project_input: - if context: - logit = ( - self.encoder_proj(encoder_out) - + self.decoder_proj(decoder_out) - + self.context_proj(context) - ) - else: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - if context is not None: - logit = encoder_out + decoder_out + context.unsqueeze(1).unsqueeze(1) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py deleted file mode 100644 index 9a693ca4f..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py +++ /dev/null @@ -1,153 +0,0 @@ -import re - -words = { - 0: "zero", - 1: "one", - 2: "two", - 3: "three", - 4: "four", - 5: "five", - 6: "six", - 7: "seven", - 8: "eight", - 9: "nine", - 10: "ten", - 11: "eleven", - 12: "twelve", - 13: "thirteen", - 14: "fourteen", - 15: "fifteen", - 16: "sixteen", - 17: "seventeen", - 18: "eighteen", - 19: "nineteen", - 20: "twenty", - 30: "thirty", - 40: "forty", - 50: "fifty", - 60: "sixty", - 70: "seventy", - 80: "eighty", - 90: "ninety", -} -ordinal_nums = [ - "zeroth", - "first", - "second", - "third", - "fourth", - "fifth", - "sixth", - "seventh", - "eighth", - "ninth", - "tenth", - "eleventh", - "twelfth", - "thirteenth", - "fourteenth", - "fifteenth", - "sixteenth", - "seventeenth", - "eighteenth", - "nineteenth", - "twentieth", -] - -num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)} - - -def year_to_words(num: int): - assert isinstance(num, int), num - # check if a num is representing a year - if num > 1500 and num < 2000: - return words[num // 100] + " " + num_to_words(num % 100) - elif num == 2000: - return "TWO THOUSAND" - elif num > 2000: - return "TWO THOUSAND AND " + num_to_words(num % 100) - else: - return num_to_words(num) - - -def num_to_words(num: int): - # Return the English words of a integer number - - # If this is a year number - if num > 1500 and num < 2030: - return year_to_words(num) - - if num < 20: - return words[num] - if num < 100: - if num % 10 == 0: - return words[num // 10 * 10] - else: - return words[num // 10 * 10] + " " + words[num % 10] - if num < 1000: - return words[num // 100] + " hundred and " + num_to_words(num % 100) - if num < 1000000: - return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000) - return num - - -def num_to_ordinal_word(num: int): - - return num_ordinal_dict.get(num, num_to_words(num)).upper() - - -def replace_full_width_symbol(s: str) -> str: - # replace full-width symbol with theri half width counterpart - s = s.replace("“", '"') - s = s.replace("”", '"') - s = s.replace("‘", "'") - s = s.replace("’", "'") - - return s - - -def decoding_normalization(text: str) -> str: - text = replace_full_width_symbol(text) - - # Only keep all alpha-numeric characters, hypen and apostrophe - text = text.replace("-", " ") - text = re.sub(r"[^a-zA-Z0-9\s']+", "", text) - return text - - -def word_normalization(word: str) -> str: - # 1 .Use full word for some abbreviation - # 2. Convert digits to english words - # 3. 
Convert ordinal number to english words - if word == "MRS": - return "MISSUS" - if word == "MR": - return "MISTER" - if word == "ST": - return "SAINT" - if word == "ECT": - return "ET CETERA" - if word.isnumeric(): - word = num_to_words(int(word)) - return str(word).upper() - # e.g 9TH, 6TH - if word[-2:] == "TH" and word[0].isnumeric(): - return num_to_ordinal_word(int(word[:-2])).upper() - if word[0] == "'": - return word[1:] - - return word - - -def simple_normalization(text: str) -> str: - text = replace_full_width_symbol(text) - text = text.replace("--", " ") - - return text - - -if __name__ == "__main__": - - s = str(1830) - out = word_normalization(s) - print(s, out) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py deleted file mode 100644 index 77b4057c4..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import random -import warnings -from typing import Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear, penalize_abs_values_gt -from torch import Tensor - -from icefall.utils import add_sos, make_pad_mask - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. 
- """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - text: - A 2-D tensor of integer dtype containing prompt text, of shape (N, T). - It is exptected to contain the style prompt (first) and then the content - prompt. - text_lens: - A 1-D tensor of shape (N,). It contains the number of elements (bytes) - in `text` before padding, which will include the lengths of the - style plus the content prompt. - style_lens: - A 1-D tensor of shape (N,), containing the number of elements (bytes) - within each row of `text` that correspond to the style prompt (these - are expected to come first). - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - x, x_lens = self.encoder_embed(x, x_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, x_lens = self.encoder( - x, - x_lens, - src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. 
- sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return (simple_loss, pruned_loss) - - def encode_audio( - self, - feature: Tensor, - feature_lens: Tensor, - memory: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """Encode the input audio features - - Args: - feature (Tensor): Input audio (N,T,C) - feature_lens (Tensor): Length of input audio (N,) - Returns: - Tuple[Tensor, Tensor]: Encoded acoustic features and length - """ - x, x_lens = self.encoder_embed(feature, feature_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder( - x=x, - x_lens=x_lens, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py deleted file mode 100644 index 21c7b4fac..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import random -import warnings -from typing import Dict, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear, penalize_abs_values_gt -from torch import Tensor - -from icefall.utils import add_sos, make_pad_mask - - -class PromptedTransducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - text_encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - use_BERT: bool = True, - text_encoder_type: str = "BERT", - text_encoder_adapter: bool = False, - freeze_text_encoder: bool = True, - context_fuser: nn.Module = None, - ): - """ - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - text_encoder: - This is a encoder that processes text information (e.g content prompt - and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - text_encoder_type: - The type of the text_encoder. Supported are (BERT, DistilBERT) - context_fuser - A optional module that fuses the embeddings of text encoder. The fused embedding - will be added to the joiner. 
- """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.text_encoder = text_encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT - self.context_fuser = context_fuser - - assert text_encoder_type in ( - "BERT", - "DistilBERT", - "BERT-UNCASED", - ), f"Unseen text_encoder type {text_encoder_type}" - self.text_encoder_dim = ( - self.text_encoder.config.hidden_size - if text_encoder_type in ("BERT", "BERT-UNCASED") - else self.text_encoder.config.dim - ) - self.freeze_text_encoder = freeze_text_encoder - - if text_encoder_adapter: - self.text_encoder_adapter = nn.Sequential( - nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False), - nn.Tanh(), - ) - else: - self.text_encoder_adapter = None - - self.style_prompt_embedding = nn.Parameter( - torch.full((self.text_encoder_dim,), 0.5) - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - encoded_inputs: Dict, - style_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - use_pre_text: bool = True, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - text: - A 2-D tensor of integer dtype containing prompt text, of shape (N, T). - It is exptected to contain the style prompt (first) and then the content - prompt. - text_lens: - A 1-D tensor of shape (N,). It contains the number of elements (bytes) - in `text` before padding, which will include the lengths of the - style plus the content prompt. - style_lens: - A 1-D tensor of shape (N,), containing the number of elements (bytes) - within each row of `text` that correspond to the style prompt (these - are expected to come first). - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer loss. 
- - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - if self.freeze_text_encoder: - self.text_encoder.eval() - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - x, x_lens = self.encoder_embed(x, x_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # freeze the BERT text encoder - - if use_pre_text: - memory, memory_key_padding_mask = self.encode_text( - encoded_inputs, style_lens=style_lens - ) - else: - memory = None - memory_key_padding_mask = None - - encoder_out, x_lens = self.encoder( - x, - x_lens, - src_key_padding_mask, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - if self.context_fuser is not None and memory is not None: - memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) - context = self.context_fuser(memory, padding_mask=memory_key_padding_mask) - context = self.joiner.context_proj(context) - else: - context = None - - logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return (simple_loss, pruned_loss) - - def _add_style_indicator(self, memory: Tensor, style_lens: Tensor): - """ - Adds to `memory` an indicator that is 1.0 for positions that correspond to - the `style prompt` and 0 elsewhere. 
The scale can be fixed because the - scale of the embedding vector can adjust to compensate. - - Args: - memory: (memory_len, batch_size, embed_dim) - style_lens: (batch_size,), a vector of lengths of the style prompt. - """ - - (memory_len, batch_size, embed_dim) = memory.shape - - indicator = ( - torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens - ) - indicator = indicator.to(memory.dtype) - - extra_term = torch.zeros_like(memory) - extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand( - memory_len, batch_size, self.text_encoder_dim - ) - - return memory + extra_term - - def encode_text( - self, - encoded_inputs: Dict, - style_lens: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Get the embeddings of text - - Args: - encoded_inputs: The encoded inputs generated by a tokenizer (Dict) - - Returns: - Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the - text_encoder and the attention mask - """ - text_lens = encoded_inputs.pop("length") # need to use pop to remove this item - - # Freeze the pre-trained text encoder - with torch.no_grad(): - memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C) - memory = memory.permute(1, 0, 2) - - # Text encoder adapter - if self.text_encoder_adapter is not None: - memory = self.text_encoder_adapter(memory) - - memory = self._add_style_indicator(memory, style_lens) - - memory_key_padding_mask = make_pad_mask(text_lens) - - return memory, memory_key_padding_mask - - def encode_audio( - self, - feature: Tensor, - feature_lens: Tensor, - memory: Optional[Tensor], - memory_key_padding_mask: Optional[Tensor], - ) -> Tuple[Tensor, Tensor]: - """Encode the input audio features - - Args: - feature (Tensor): Input audio (N,T,C) - feature_lens (Tensor): Length of input audio (N,) - memory (Tensor): Embeddings from the text encoder - memory_key_padding_mask (Tensor): _description_ - - Returns: - Tuple[Tensor, Tensor]: _description_ - """ - x, x_lens = self.encoder_embed(feature, feature_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder( - x=x, - x_lens=x_lens, - src_key_padding_mask=src_key_padding_mask, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -Transducer = PromptedTransducer # for decoding diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py deleted file mode 100644 index 159e363c7..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py +++ /dev/null @@ -1,1164 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
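`_add_style_indicator` above marks which positions of the text-encoder memory belong to the style prompt by adding a learned embedding at exactly those positions. The masking trick (comparing a position index against `style_lens`) can be reproduced in isolation with dummy tensors:

```python
import torch

memory_len, batch_size, embed_dim = 6, 2, 4
memory = torch.zeros(memory_len, batch_size, embed_dim)
style_prompt_embedding = torch.full((embed_dim,), 0.5)
style_lens = torch.tensor([2, 4])  # first 2 (resp. 4) positions are the style prompt

# (memory_len, batch_size): 1.0 where the position is part of the style prompt
indicator = (torch.arange(memory_len).unsqueeze(-1) < style_lens).to(memory.dtype)
out = memory + indicator.unsqueeze(-1) * style_prompt_embedding

print(out[:, 0, 0])  # tensor([0.5, 0.5, 0.0, 0.0, 0.0, 0.0])
```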
- -import contextlib -import logging -import random -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! 
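`batched_params` above groups parameters by `(dtype, shape)`, stacks each group (and its gradients) along a new leading dimension so the update runs on one tensor per group, and copies the results back afterwards. The grouping-and-stacking step in isolation, on plain tensors:

```python
from collections import defaultdict

import torch

params = [torch.randn(3, 5), torch.randn(3, 5), torch.randn(7), torch.randn(3, 5)]

# Group by (dtype, shape), the same key used by BatchedOptimizer above.
batches = defaultdict(list)
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)

for key, group in batches.items():
    stacked = torch.stack(group)  # one extra leading dim -> one kernel launch per group
    print(key, stacked.shape)
# ('torch.float32', 3, 5) torch.Size([3, 3, 5])
# ('torch.float32', 7) torch.Size([1, 7])
```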
- - for (stacked_params, _state, _names), batch in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accept model.parameters() or groups of parameters(), - this optimizer could accept model.named_parameters() or groups of named_parameters(). - See comments of function _get_names_of_parameters for its 4 possible cases. - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. 
- clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True - param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. 
- # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. 
- p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - numel = p.numel() - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for p, state, param_names in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. 
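For concreteness, `param_rms` as initialized here is one root-mean-square value per stacked parameter, reduced over every dim except the leading batch dim and kept broadcastable:

    import torch

    p = torch.randn(3, 256, 5)  # 3 stacked parameters, each of shape (256, 5)
    param_rms = (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
    print(param_rms.shape)      # torch.Size([3, 1, 1]) -- broadcasts against p later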
- sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter which dominates tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for p, state, batch_param_names in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummy values used by following `zip` statement. 
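The clipping rule computed above amounts to: keep a rolling window of recent gradient norms (each gradient first scaled by its parameter's rms, as the docstring explains), set the threshold to `clipping_scale` times their median, and shrink the current gradient if it exceeds that threshold. A small sketch under those assumptions (the real code tracks the window in `first_state["model_norms"]` and refreshes the threshold every `clipping_update_period` steps):

    import torch

    def clipping_factor(recent_norms: torch.Tensor, cur_norm: float, clipping_scale: float) -> float:
        threshold = clipping_scale * recent_norms.median().item()
        return min(1.0, threshold / (cur_norm + 1.0e-20))

    recent = torch.tensor([0.9, 1.0, 1.2, 5.0])
    print(clipping_factor(recent, cur_norm=5.0, clipping_scale=2.0))  # 0.4 -> gradient scaled down
    print(clipping_factor(recent, cur_norm=0.5, clipping_scale=2.0))  # 1.0 -> no clipping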
- batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter dominating tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq={(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. 
- p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. 
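Momentum and bias correction aside, the heart of `_step` above is an Adam-style update whose step size is proportional to the parameter's own rms (clamped below by `param_min_rms`) rather than being absolute. A stripped-down sketch of just that scaling, with the `delta` momentum buffer and the `(1 - beta1)` factor deliberately omitted:

    import torch

    def scaled_step_(p, grad, exp_avg_sq, param_rms, lr=0.03, beta2=0.98, eps=1e-8, param_min_rms=1e-5):
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.sqrt() + eps
        # the update is relative to the parameter's magnitude, not absolute
        p.add_(-lr * param_rms.clamp(min=param_min_rms) * grad / denom)

    p = torch.randn(4, 4)
    scaled_step_(p, torch.randn(4, 4), torch.zeros(4, 4), p.pow(2).mean().sqrt())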
- bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[Union[int, float]] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. 
_On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. 
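Eve's distinguishing detail is the conditional weight decay: it is applied only while a tensor's norm stays above `target_rms * sqrt(numel)`, and never to scalar "scaling factor" parameters. Written out on its own, with the same defaults as the class above:

    import torch

    def eve_weight_decay_(p: torch.Tensor, weight_decay: float = 1e-3, target_rms: float = 0.1) -> None:
        if p.numel() > 1:  # never decay scalar "scaling factor" parameters
            is_above_target_rms = p.norm() > target_rms * (p.numel() ** 0.5)
            p.mul_(1 - weight_decay * is_above_target_rms)

    w = torch.full((10, 10), 0.5)  # rms 0.5 > target_rms, so decay applies
    eve_weight_decay_(w)
    print(w[0, 0].item())          # ~0.4995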
- input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 512 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py deleted file mode 100644 index 458109a3f..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py +++ /dev/null @@ -1,360 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script loads a checkpoint (`pretrained.pt`) and uses it to decode waves. -You can generate the checkpoint with the following command: - -./zipformer/export_PromptASR.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ - --epoch 50 \ - --avg 10 - -Utterance level context biasing: - -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ - --method modified_beam_search \ - --use-pre-text True \ - --content-prompt "bessy random words hello k2 ASR" \ - --use-style-prompt True \ - librispeech.flac - - -Word level context biasing: - -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ - --method modified_beam_search \ - --use-pre-text True \ - --content-prompt "The topic is about horses." \ - --use-style-prompt True \ - test.wav - - -""" - -import argparse -import logging -import math -import warnings -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import greedy_search_batch, modified_beam_search -from text_normalization import _apply_style_transform, train_text_normalization -from torch.nn.utils.rnn import pad_sequence -from train_bert_encoder import ( - _encode_texts_as_bytes_with_tokenizer, - add_model_arguments, - get_params, - get_tokenizer, - get_transducer_model, -) - -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500_fallback_coverage_0.99/bpe.model", - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. 
Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - parser.add_argument( - "--use-pre-text", - type=str2bool, - default=True, - help="Use content prompt during decoding", - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Use style prompt during decoding", - ) - - parser.add_argument( - "--pre-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of content prompt, i.e pre_text", - ) - - parser.add_argument( - "--style-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of style prompt, i.e style_text", - ) - - parser.add_argument( - "--content-prompt", type=str, default="", help="The content prompt for decoding" - ) - - parser.add_argument( - "--style-prompt", - type=str, - default="Mixed-cased English text with punctuations, feel free to change it.", - help="The style prompt for decoding", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
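The waveform-to-batch path in `main()` below boils down to: load each file at the expected sample rate, keep only the first channel, compute fbank features, and pad the variable-length features into one batch with a very negative log-energy value. A hedged sketch of just the padding step (the feature tensors here are random stand-ins for real fbank output):

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    feats = [torch.randn(120, 80), torch.randn(95, 80)]  # (num_frames, feature_dim) per file
    feature_lengths = torch.tensor([f.size(0) for f in feats])
    batch = pad_sequence(feats, batch_first=True, padding_value=math.log(1e-10))
    print(batch.shape, feature_lengths)  # torch.Size([2, 120, 80]) tensor([120,  95])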
- - logging.info("Creating model") - model = get_transducer_model(params) - tokenizer = get_tokenizer(params) # for text encoder - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - assert ( - len(params.sound_files) == 1 - ), "Only support decoding one audio at this moment" - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # encode prompts - if params.use_pre_text: - pre_text = [train_text_normalization(params.content_prompt)] - pre_text = _apply_style_transform(pre_text, params.pre_text_transform) - else: - pre_text = [""] - - if params.use_style_prompt: - style_text = [params.style_prompt] - style_text = _apply_style_transform(style_text, params.style_text_transform) - else: - style_text = [""] - - if params.use_pre_text or params.use_style_prompt: - encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( - pre_texts=pre_text, - style_texts=style_text, - tokenizer=tokenizer, - device=device, - no_limit=True, - ) - - memory, memory_key_padding_mask = model.encode_text( - encoded_inputs=encoded_inputs, - style_lens=style_lens, - ) # (T,B,C) - else: - memory = None - memory_key_padding_mask = None - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=features, - feature_lens=feature_lengths, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - if params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - hyps.append(sp.decode(hyp_tokens)[0]) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - hyps.append(sp.decode(hyp_tokens)[0]) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py deleted file mode 100644 index 0e6764ba0..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py +++ /dev/null @@ -1,1872 +0,0 
@@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import logging -import math -import random -from functools import reduce -from itertools import repeat -from typing import Optional, Tuple, Union - -import k2 -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn import Embedding as ScaledEmbedding - - -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - - def __init__(self, *args): - assert len(args) >= 1 - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [(float(x), float(y)) for x, y in args] - for (x, y) in self.pairs: - assert isinstance(x, float) or isinstance(x, int) - assert isinstance(y, float) or isinstance(y, int) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f"PiecewiseLinear({str(self.pairs)[1:-1]})" - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise lienar - functions to self and p, but with the same x values. 
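A quick usage check of `PiecewiseLinear` as specified above, assuming the class is in scope (values outside the given x-range clamp to the first/last y):

    pl = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
    print(pl(-5.0))   # 10.0 -- before the first x, returns the first y
    print(pl(50.0))   # 5.0  -- linear interpolation between the two points
    print(pl(200.0))  # 0.0  -- after the last x, returns the last y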
- - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. - """ - assert isinstance(p, PiecewiseLinear) - - # get sorted x-values without repetition. - x_vals = sorted(set([x for x, y in self.pairs] + [x for x, y in p.pairs])) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - return ( - PiecewiseLinear(*zip(x_vals, y_vals1)), - PiecewiseLinear(*zip(x_vals, y_vals2)), - ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specifiy the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or in training or mode or in - torch.jit scripting mode. - """ - - def __init__(self, *args, default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return ( - f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" - ) - - def __float__(self): - batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting(): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info( - f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" - ) - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, default=self.default) - else: - return ScheduledFloat( - self.schedule + x.schedule, default=self.default + x.default - ) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), default=self.default) - else: - return ScheduledFloat( - self.schedule.max(x.schedule), default=max(self.default, x.default) - ) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. 
- """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = x > self.cutoff - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1 - q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. 
- variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return (x - bias) * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward( - ctx, - x: Tensor, - bias: Tensor, - log_scale: Tensor, - channel_dim: int, - store_output_for_backprop: bool, - ) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ctx.save_for_backward( - ans.detach() if store_output_for_backprop else x, - scales.detach(), - bias.detach(), - log_scale.detach(), - ) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = ( - torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None - - -class BiasNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. 
- log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False, - ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * self.log_scale.exp() - return x * scales - - log_scale = limit_param_value( - self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training, - ) - - return BiasNormFunction.apply( - x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop - ) - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. 
- """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - - def __init__( - self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True, - ): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. - self.causal_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True, - ) - - self.chunkwise_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias, - ) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_( - self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) - - def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. 
- left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., : left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( - batch_size * num_chunks, num_channels, chunk_size - ) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape( - batch_size, num_chunks, num_channels, chunk_size - ).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ - ..., :seq_len - ] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros( - channels, t, device=left_edge.device, dtype=left_edge.dtype - ) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - (x,) = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. 
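# Editorial worked example (made-up numbers): if a channel has rms = 2.0 while
# max_rms = 1.0, then rms_clamped / rms = 0.5 and r_loss = |log(0.5)| ~= 0.69 below;
# a channel whose rms already lies inside [min_rms, max_rms] contributes r_loss = 0.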
- rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." - ) - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. 
- self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) - ): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt( - x: Tensor, limit: float, penalty: float, name: str = None -) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). 
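# Editorial worked example (made-up numbers): with limit = 10.0 and an element
# x = 12.0, the expression below contributes penalty * 12.0 to aux_loss, while the
# "real" penalty would be penalty * (12.0 - 10.0); both have derivative `penalty`
# with respect to x, and only the derivative of aux_loss matters here. Elements
# with |x| <= limit contribute zero either way.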
- aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. 
limit={float(w.whitening_limit)}" - ) - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = w.grad_scale * ( - x_grad.to(torch.float32).norm() - / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info( - f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." - ) - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float, float]], - grad_scale: FloatLike, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), - None, - ) - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. 
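# Editorial usage sketch (hypothetical variable names): given attention scores
# `scores` and an auxiliary penalty tensor `aux` of matching shape,
#     scores = with_loss(scores, aux, name="attn_scores")
# returns `scores` unchanged in the forward pass; in the backward pass `aux`
# receives a gradient of all-ones, i.e. the parameters are updated as if
# aux.sum() had been added to the training loss.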
- return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x,) = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where( - torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 - ) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value( - x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True -): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. 
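# Editorial worked example (illustrative numbers): a derivative value of 0.5 is
# mapped to (0.5 - (-0.044)) * 255 / (1.2 - (-0.044)) ~= 111.5; adding uniform noise
# in [0, 1) and truncating to uint8 then yields 111 or 112 with probabilities whose
# expectation recovers ~0.5, so the 1-byte storage only adds zero-mean noise of at
# most one quantization step (about 1.244 / 255 ~= 0.005) to the stored derivative.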
- floor = -0.044 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - (ans,) = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. 
- assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. 
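# Editorial note on the overflow guard in these forward-only helpers: for large
# inputs x_offset.exp() overflows to inf (already in float32 for x beyond roughly 90),
# which would make (1 + x_offset.exp()).log() inf as well; the torch.where() falls
# back to the asymptote log(1 + exp(x_offset)) ~= x_offset, so a hypothetical
# x = 100.0 gives SwooshLForward(x) ~= 96 - 8 - 0.035 ~= 88 instead of inf.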
-def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int], - ): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = (1.0 / (1.0 - dropout_p)) * ( - torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p - ) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - "SwooshL": k2.swoosh_l_forward, - "SwooshR": k2.swoosh_r_forward, - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - "SwooshL": k2.swoosh_l_forward_and_deriv, - "SwooshR": k2.swoosh_r_forward_and_deriv, - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). 
If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = "SwooshL", - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0, - ): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - layer = ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=initial_scale - ) - - self.weight = layer.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter("bias", layer.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, x: Tensor): - if torch.jit.is_scripting(): - if self.activation == "SwooshL": - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, self.weight, self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, - self.weight, - self.bias, - self.activation, - float(self.dropout_p), - self.dropout_shared_dim, - ) - - -class ClipGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, limit: float): - ctx.limit = limit - return x - - @staticmethod - def backward(ctx, x_grad, *args): - return x_grad.clamp(-ctx.limit, ctx.limit), None - - -def clip_grad(x: Tensor, limit: float): - return ClipGradFunction.apply(x, limit) - - -class AbsValuePenalizer(nn.Module): - """ - This module adds a penalty to the loss function when ever the absolute value of - any element of the input tensor exceeds a certain limit. - """ - - def __init__(self, limit: float, prob: float = 0.1, penalty: float = 1.0e-04): - super().__init__() - self.limit = limit - self.penalty = penalty - - self.prob = prob - self.name = None # will be set in training loop - - # 20% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.2) - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or not self.training - or random.random() > self.prob - ): - # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) - return _no_op(x) # the _no_op op is to make our diagnostics code work. 
- - x = penalize_abs_values_gt( - x, limit=self.limit, penalty=self.penalty, name=self.name - ) - return x - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshL() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshR() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. 
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_piecewise_linear(): - p = PiecewiseLinear((0, 10.0)) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear((0, 10.0), (1, 0.0)) - for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: - for activation in ["SwooshL", "SwooshR"]: - m1 = nn.Sequential( - SwooshL() if activation == "SwooshL" else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=0.5 - ), - ) - m2 = ActivationDropoutAndLinear( - in_channels, - out_channels, - bias=bias, - initial_scale=0.5, - activation=activation, - dropout_p=dropout_p, - ) - with torch.no_grad(): - m2.weight[:] = m1[2].weight - if bias: - m2.bias[:] = m1[2].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - # TEMP. - assert torch.allclose( - SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 - ) - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print( - f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" - ) - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) - if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ( - (a**2).sum() * (b**2).sum() - ).sqrt() - - # the SwooshL() implementation has a noisy gradient due to 1-byte - # storage of it. 
- assert isclose(x1.grad, x2.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_piecewise_linear() - _test_softmax() - _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_swooshl_deriv() - _test_swooshr_deriv() - _test_activation_dropout_and_linear() - _test_double_swish_deriv() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py deleted file mode 100644 index 7acbc1808..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py +++ /dev/null @@ -1,276 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import Tuple - -import torch -from scaling import ( - Balancer, - BiasNorm, - Dropout3, - FloatLike, - Optional, - ScaledConv2d, - ScaleGrad, - ScheduledFloat, - SwooshL, - SwooshR, - Whiten, -) -from torch import Tensor, nn - - -class ConvNeXt(nn.Module): - """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf - """ - - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, - ): - super().__init__() - padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 - ) - - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - initial_scale=0.01, - ) - - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def 
forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None - ) -> Tensor: - """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. - """ - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - if layer_skip_mask is not None: - x = x * layer_skip_mask - - x = bypass + x - x = self.out_balancer(x) - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) - - return x - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite - """ - assert in_channels >= 7 - super().__init__() - - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), - SwooshR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - ) - - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - out_width = (((in_channels - 1) // 2) - 1) // 2 - - self.out = nn.Linear(out_width * layer3_channels, out_channels) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), - prob=(0.025, 0.25), - grad_scale=0.02, - ) - - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). 
- x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - - output lengths, of shape (batch_size,) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() - - return x, x_lens diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py b/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py deleted file mode 100755 index 13483637d..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py -""" - -from scaling import ScheduledFloat -from train_subformer import get_params, get_text_encoder, get_transducer_model -from zipformer import Zipformer2 - - -def test_model_1(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 24 - params.dim_feedforward = 1536 # 384 * 4 - params.encoder_dim = 384 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf -def test_model_M(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,15,15" - - params.text_encoder_dim = (192, 192, 256, 384) - params.decoder_dim = 512 - params.joiner_dim = 512 - model = Zipformer2( - output_downsampling_factor=8, - downsampling_factor=(1, 2, 4, 8), - num_encoder_layers=(2, 4, 4, 4), - encoder_dim=(192, 192, 256, 384), - encoder_unmasked_dim=(192, 192, 256, 256), - query_head_dim=(32, 32, 32, 32), - pos_head_dim=(4, 4, 4, 4), - value_head_dim=(12, 12, 12, 12), - pos_dim=48, - num_heads=(4, 4, 4, 8), - feedforward_dim=( - 384, - 512, - 768, - 1024, - ), # could increase this if there is nough data - cnn_module_kernel=(31, 31, 15, 15), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=False, - ) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - model = Zipformer2( - output_downsampling_factor=8, - downsampling_factor=(1, 2, 4, 8), - num_encoder_layers=(2, 4, 6, 6), - encoder_dim=(256, 256, 384, 512), - encoder_unmasked_dim=(196, 196, 256, 256), - query_head_dim=(32, 32, 32, 32), - pos_head_dim=(4, 4, 4, 4), - value_head_dim=(12, 12, 12, 12), - pos_dim=48, - num_heads=(4, 4, 4, 8), - feedforward_dim=( - 384, - 512, - 768, - 1024, - ), # could increase this if there is nough data - cnn_module_kernel=(31, 31, 15, 15), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=False, - ) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -def main(): - # test_model_1() - test_model_M() - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py deleted file mode 100644 index efb4acc3c..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import List - - -def train_text_normalization(s: str) -> str: - # replace full-width with half-width - s = s.replace("“", '"') - s = s.replace("”", '"') - s = s.replace("‘", "'") - s = s.replace("’", "'") - if s[:2] == '" ': # remove the starting double quote - s = s[2:] - - return s - - -def ref_text_normalization(ref_text: str) -> str: - # Rule 1: Remove the [FN#[]] - p = r"[FN#[0-9]*]" - pattern = re.compile(p) - - res = pattern.findall(ref_text) - ref_text = re.sub(p, "", ref_text) - - ref_text = train_text_normalization(ref_text) - - return ref_text - - -def remove_non_alphabetic(text: str, strict: bool = True) -> str: - # Recommend to set strict to False - if not strict: - # Note, this also keeps space, single quote(') and hypen (-) - text = text.replace("-", " ") - text = text.replace("—", " ") - return re.sub(r"[^a-zA-Z0-9\s']+", "", text) - else: - # only keeps space - return re.sub(r"[^a-zA-Z\s]+", "", text) - - -def upper_only_alpha(text: str) -> str: - return remove_non_alphabetic(text.upper(), strict=False) - - -def lower_only_alpha(text: str) -> str: - return remove_non_alphabetic(text.lower(), strict=False) - - -def lower_all_char(text: str) -> str: - return text.lower() - - -def upper_all_char(text: str) -> str: - return text.upper() - - -def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in - ground truth format, i.e mixed-punc. - - Args: - text (List[str]): Input text string - transform (str): Transform to be applied - - Returns: - List[str]: _description_ - """ - if transform == "mixed-punc": - return text - elif transform == "upper-no-punc": - return [upper_only_alpha(s) for s in text] - elif transform == "lower-no-punc": - return [lower_only_alpha(s) for s in text] - elif transform == "lower-punc": - return [lower_all_char(s) for s in text] - else: - raise NotImplementedError(f"Unseen transform: {transform}") - - -if __name__ == "__main__": - ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." - print(ref_text) - res = upper_only_alpha(ref_text) - print(res) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py deleted file mode 100644 index 93f7e1248..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ /dev/null @@ -1,1418 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
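For reference, this is what the two case transforms defined above do to a mixed-case, punctuated sentence. The sketch below repeats the same regular expression in a self-contained form rather than importing the module:

import re

def _keep_alnum_space_quote(text: str) -> str:
    # strict=False branch above: map dashes to spaces, then keep only
    # letters, digits, whitespace and single quotes.
    text = text.replace("-", " ").replace("—", " ")
    return re.sub(r"[^a-zA-Z0-9\s']+", "", text)

s = "Mixed-case English transcription, with punctuation."
print(_keep_alnum_space_quote(s.upper()))  # MIXED CASE ENGLISH TRANSCRIPTION WITH PUNCTUATION
print(_keep_alnum_space_quote(s.lower()))  # mixed case english transcription with punctuation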
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - - -# For mix precision training, using MCP style transcript: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./zipformer_prompt_asr/train_baseline.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer_prompt_asr/exp \ - --transcript-style MCP \ - --max-duration 1000 - -# For mix precision training, using UC style transcript: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./zipformer_prompt_asr/train_baseline.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer_prompt_asr/exp \ - --transcript-style UC \ - --max-duration 1000 - -# To train a streaming model - -./zipformer_prompt_asr/train_baseline.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --causal 1 - --exp-dir zipformer/exp \ - --max-duration 1000 - -""" - - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model_baseline import Transducer -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from text_normalization import train_text_normalization, upper_only_alpha -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_mixed_cased_with_punc( - texts: List[str], - pre_texts: List[str], - context_list: Optional[str] = None, - rare_word_list: Optional[List[str]] = None, -) -> str: - # Always get the first one, which is the mixed-cased text with punc - out = {"text": texts[0], "pre_text": pre_texts[0]} - return out - - -def get_upper_only_alpha( - texts: List[str], - pre_texts: List[str], - context_list: Optional[str] = None, - rare_word_list: Optional[List[str]] = None, -) -> str: - # Always get the first one, which is the mixed-cased text with punc, - # but with upper case it and remove punctuation - out = { - "text": upper_only_alpha(texts[0]), - "pre_text": upper_only_alpha(pre_texts[0]), - } - return out - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. 
This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-dim", - type=str, - default="256,256,384,512", - help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, " - "or you should change other configs in the code.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--transcript-style", - type=str, - default="UC", - choices=["UC", "MCP"], - help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations, - MCP stands for mix-cased text with punctuation. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. 
- - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: 
nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. 
See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - texts = [train_text_normalization(t) for t in texts] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - if random.random() < 0.02: - logging.info(f"Ref texts: {texts[0]}") - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. 
- - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
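The grad-scale growth policy applied below can be summarized as a small standalone helper. This is a readability sketch of the same thresholds, not a drop-in replacement (it omits the bad-model checkpointing and the warning logs):

def adjust_grad_scale(scaler, batch_idx: int) -> None:
    # Push the GradScaler back up when its scale has collapsed:
    # always double below 8.0, and below 32.0 only every 400 batches.
    cur = scaler._scale.item()
    if cur < 8.0 or (cur < 32.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)
    if cur < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small: {cur}")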
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - libriheavy = LibriHeavyAsrDataModule(args) - - train_cuts = libriheavy.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 30.0: - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].texts[0]}. " - f"Tokens: {tokens}. 
" - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - if params.transcript_style == "UC": - text_sampling_func = get_upper_only_alpha - else: - text_sampling_func = get_mixed_cased_with_punc - logging.info(f"Using {params.transcript_style} style for training.") - logging.info(f"Text sampling func: {text_sampling_func}") - train_dl = libriheavy.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - text_sampling_func=text_sampling_func, - ) - - valid_cuts = libriheavy.dev_cuts() - valid_dl = libriheavy.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py deleted file mode 100755 index 2a2c206aa..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ /dev/null @@ -1,1797 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, -# -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
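A batch saved by display_and_save_batch() above can be inspected offline when a crash or OOM needs debugging. A minimal sketch, assuming a saved file of the batch-<uuid>.pt form (the exact file name is generated at run time, so the path below is hypothetical):

import torch

# Hypothetical path; the real name contains a random uuid.
batch = torch.load(
    "zipformer_prompt_asr/exp/batch-xxxx.pt",
    map_location="cpu",
    weights_only=False,  # the saved dict holds plain Python objects, not just tensors
)

features = batch["inputs"]                        # (N, T, C) fbank features
texts = batch["supervisions"]["text"]
num_frames = batch["supervisions"]["num_frames"]

print(features.shape)
print(num_frames)
print(texts[0])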
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For mix precision training: - -(1) Non-streaming model, **without** context list - -./zipformer_prompt_asr/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --subset medium \ - --causal False \ - --exp-dir zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --memory-layer 0 \ - --text-encoder-type BERT \ - --text-encoder-dim 768 \ - --use-style-prompt True \ - --use-context-list False - -(2) Non-streaming model, **with** context list - -./zipformer_prompt_asr/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --subset medium \ - --causal False \ - --exp-dir zipformer_prompt_asr/exp \ - --max-duration 1000 \ - --memory-layer 0 \ - --text-encoder-type BERT \ - --text-encoder-dim 768 \ - --use-style-prompt True \ - --use-context-list True \ - --top-k 10000 \ - --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt - - -""" - - -import argparse -import copy -import logging -import os -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import numpy -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriHeavyAsrDataModule -from dataset import ( - naive_triplet_text_sampling, - random_shuffle_subset, - triplet_text_sampling, - triplet_text_sampling_with_context_list, -) -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model_with_BERT import PromptedTransducer -from optim import Eden, ScaledAdam -from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR -from subsampling import Conv2dSubsampling -from text_normalization import ( - lower_all_char, - lower_only_alpha, - train_text_normalization, - upper_all_char, - upper_only_alpha, -) -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - -style_transforms = [ - lambda x: x, # return it self - upper_only_alpha, - lower_only_alpha, - lower_all_char, -] - - -def get_first(texts: List[str], pre_texts: List[str]) -> str: - out = { - "text": texts[0], - "pre_text": pre_texts[0], - "style_text": "", - "transform_ids": 0, - } - return out - - -def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str: - # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha - out = { - "text": upper_only_alpha(texts[0]), - "pre_text": upper_only_alpha(pre_texts[0]), - 
"style_text": "", - "transform_ids": 0, - } - return out - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--memory-dropout-rate", - type=float, - default=0.05, - help="By which probability, dropout the memory when doing cross-attention.", - ) - - parser.add_argument( - "--memory-layer", - type=int, - default=0, - help="Start doing cross-attention from which layer. Zero-indexed", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--freeze-text-encoder", - type=str2bool, - default=True, - ) - - parser.add_argument( - "--text-encoder-type", - type=str, - default="BERT", - choices=["BERT", "DistilBERT"], - help="Type of the text encoder", - ) - - parser.add_argument( - "--text-encoder-dim", - type=int, - default=768, - help="Dimension of the text encoder", - ) - - parser.add_argument( - "--text-encoder-adapter", - type=str2bool, - default=False, - help="An adapter for pre-trained BERT", - ) - - parser.add_argument( - "--context-injection", - type=str2bool, - default=False, - help="Inject context embedding into the joiner", - ) - - parser.add_argument( - "--context-dropout-rate", - type=float, - default=0.05, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. 
- """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Whether to use style prompt.", - ) - - # arguments for using prompt - parser.add_argument( - "--pre-text-shuffle-prob", - type=float, - default=0.05, - help="The proportion of pre_text to be shuffled with in a batch", - ) - - parser.add_argument( - "--style-text-shuffle-prob", - type=float, - default=0.2, - help="The proportion of style_text to be shuffled with in a batch", - ) - - parser.add_argument( - "--prompt-mask-prob", - type=float, - default=0.05, - help="The probability of masking prompts", - ) - - parser.add_argument( - "--forced-upper-pre-text", - type=str2bool, - default=False, - help="Forced format of pre-text", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -class TextEmbedding(nn.Module): - def __init__( - self, - num_embeddings: int = 256, - embedding_dim: int = 256, - kernel_size: int = 3, - layer1_channels: int = 256, - layer2_channels: int = 256, - bias: bool = True, - dropout: float = 0.1, - ): - super().__init__() - self.embed = nn.Embedding( - num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes - embedding_dim=embedding_dim, # - ) - - assert embedding_dim == layer1_channels # for depth wise convolution - self.conv = nn.Sequential( - nn.Conv1d( - embedding_dim, - layer1_channels, # depthwise convolution - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=layer1_channels, - bias=True, - ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), - nn.ReLU(), - nn.Conv1d( - layer1_channels, - layer2_channels, - kernel_size=1, # pointwise convolution - stride=1, - padding=0, - bias=True, - ), - Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), - nn.ReLU(), - ) - - self.out_norm = BiasNorm(layer2_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def forward(self, text: torch.Tensor) -> torch.Tensor: - """Forward function of the text embedding - - Args: - text (torch.Tensor): Text in UTF-8 bytes (T,N) - Returns: - The embeddings of text (T,N,C) - """ - text = self.embed(text) # (T,N,C) - - # src = text - text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T) - text = self.conv(text) - text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C) - # src = src + text - - text = self.out_norm(text) - text = self.dropout(text) - - return text - - -def get_text_encoder(params: AttributeDict) -> nn.Module: - # Return a text encoder - if params.text_encoder_type == "BERT": # This is a BERT-base-cased - from transformers import BertModel - - logging.info("Loading pre-trained BERT-base-cased as text encoder") - if os.path.exists("data/models/bert-base-cased"): - model = BertModel.from_pretrained("data/models/bert-base-cased") - else: - model = BertModel.from_pretrained("bert-base-cased") - assert params.text_encoder_dim == 768 - elif params.text_encoder_type == "BERT-large": - from transformers import BertModel - - logging.info("Loading pre-trained BERT-large-uncased as text encoder") - if os.path.exists("data/models/bert-large-uncased"): - model = BertModel.from_pretrained("data/models/bert-large-uncased") - else: - model = BertModel.from_pretrained("bert-large-uncased") - assert params.text_encoder_dim == 1024 - elif params.text_encoder_type == "DistilBERT": - from transformers import DistilBertModel # This is a DistilBERT-base-cased - - logging.info("Loading pre-trained DistilBERT-base-cased as text encoder") - model = DistilBertModel.from_pretrained("distilbert-base-cased") - assert params.text_encoder_dim == 768 - else: - raise ValueError() - - return model - - -def get_tokenizer(params: AttributeDict): - - if params.text_encoder_type == "BERT": - from transformers import BertTokenizer - - # This is a BERT-base-cased - if os.path.exists("data/models/bert-base-cased"): - tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased") - else: - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - elif params.text_encoder_type == "BERT-large": - from transformers import BertTokenizer - - # This is a BERT-large-uncased - if 
os.path.exists("data/models/bert-large-uncased"): - tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased") - else: - tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") - elif params.text_encoder_type == "DistilBERT": - from transformers import DistilBertTokenizer - - tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased") - else: - raise ValueError() - - return tokenizer - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - memory_dim=params.text_encoder_dim, # This is fixed as the BERT base model is 768-D - memory_layer=params.memory_layer, - memory_dropout_rate=params.memory_dropout_rate, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - context_dim=( - 4 * 768 if params.context_injection else -1 - ), # the output dim of text encoder - context_injection=params.context_injection, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - text_encoder = get_text_encoder(params) # This should be a cased BERT base model - num_param = sum([p.numel() for p in text_encoder.parameters()]) - logging.info(f"Num params in text encoder: {num_param}") - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = PromptedTransducer( - encoder_embed=encoder_embed, - encoder=encoder, - text_encoder=text_encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - text_encoder_type=params.text_encoder_type, - text_encoder_adapter=params.text_encoder_adapter, - freeze_text_encoder=params.freeze_text_encoder, - context_fuser=None, - ) - - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def _encode_texts_as_bytes_with_tokenizer( - pre_texts: List[str], - style_texts: List[str], - tokenizer, - device: torch.device, - max_len: int = 500, - no_limit: bool = False, -) -> Tuple[Dict, Tensor]: - """ - Encode texts as bytes and then integer tensors. - Note that the style text will be added to the beginning of texts. 
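A minimal usage sketch for the helper whose body follows, assuming a HuggingFace BERT tokenizer on CPU; the example prompt strings are made up, and the import path mirrors how transcribe_bert.py imports this helper.

    import torch
    from transformers import BertTokenizer
    from train_bert_encoder import _encode_texts_as_bytes_with_tokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    device = torch.device("cpu")

    pre_texts = ["previously decoded text for this recording"]
    style_texts = ["Mixed-case English transcription, with punctuation."]

    encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
        pre_texts=pre_texts,
        style_texts=style_texts,
        tokenizer=tokenizer,
        device=device,
    )
    # encoded_inputs is the tokenizer output for "style_text [SEP] pre_text";
    # style_lens holds the token length of each style prompt.
    print(encoded_inputs["input_ids"].shape, style_lens)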
- """ - batch_size = len(pre_texts) - max_len = min(max_len, 500) - - if no_limit: - allowed_lens = [5000 - len(s) for s in style_texts] - else: - allowed_lens = [1000 - len(s) for s in style_texts] - truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)] - combined_text = [ - style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size) - ] - - encoded_style_texts = tokenizer( - style_texts, - return_tensors="pt", - padding=True, - truncation=True, - return_length=True, - max_length=max_len, - ) - style_lens = encoded_style_texts["length"].to(device) - - # Use tokenizer to prepare input for text encoder - encoded_inputs = tokenizer( - combined_text, - return_tensors="pt", - padding=True, - truncation=True, - return_length=True, - max_length=max_len, - ).to(device) - - return encoded_inputs, style_lens - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - batch_size = feature.size(0) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - pre_texts = batch["supervisions"]["pre_text"] - style_texts = batch["supervisions"][ - "style_text" - ] # the style texts are in gt format - transform_ids = batch["supervisions"]["transform_ids"] - - # This is to replace full-width symbols with half-width symbols - texts = [train_text_normalization(t) for t in texts] - pre_texts = [train_text_normalization(t) for t in pre_texts] - style_texts = [train_text_normalization(t) for t in style_texts] - - y = sp.encode( - texts, out_type=int - ) # sp.encode treats consecutive space as a single space - y = k2.RaggedTensor(y).to(device) - - if params.forced_upper_pre_text: - pre_texts = [upper_only_alpha(p) for p in pre_texts] - - # only shuffle the pre_text and style texts if during training, and use style prompt - if is_training: - # randomly shuffle&mask the pre_text - pre_texts = random_shuffle_subset( - pre_texts, - p=params.pre_text_shuffle_prob, - p_mask=params.prompt_mask_prob, - ) - - if params.use_style_prompt: - if random.random() < 0.5: - # randomly shuffle the style_text - # now the style_texts are all in gt format - style_texts = random_shuffle_subset( - style_texts, - p=params.style_text_shuffle_prob, - p_mask=params.prompt_mask_prob, - ) - - assert len(transform_ids) == len(style_texts) - - for i in range(len(style_texts)): - t = transform_ids[i] # get the transform id - style_texts[i] = style_transforms[t](style_texts[i]) - - if not 
params.use_style_prompt: - style_texts = [ - "" for _ in style_texts - ] # use empty string for style texts if don't use style prompt - - if random.random() < 0.05: - logging.info(f"Pre texts: {pre_texts[0]}") - logging.info(f"Ref texts: {texts[0]}") - logging.info(f"Style texts: {style_texts[0]}") - - encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( - pre_texts=pre_texts, - style_texts=style_texts, - tokenizer=tokenizer, - device=device, - ) - - if random.random() < 0.02: - logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ") - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - encoded_inputs=encoded_inputs, - style_lens=style_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. 
- scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if not params.use_style_prompt: - if params.pre_text_shuffle_prob == 0.0: - logging.info( - f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}" - ) - logging.info( - "If style prompt is not used, you should be careful when shuffling the pre_text within the same batch" - ) - logging.info("Hard set this probability to 0.0!") - params.pre_text_shuffle_prob = 0.0 - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - tokenizer = get_tokenizer(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - if params.freeze_text_encoder: - freeze_modules = ["text_encoder"] - logging.info( - "Freeze the parameters of text encoder and don't include them in the optimizer" - ) - else: - freeze_modules = [] - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs( - model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules - ), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - args.max_duration = 100 - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - libriheavy = LibriHeavyAsrDataModule(args) - - train_cuts = libriheavy.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 30.0: - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].texts[0]}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - if params.use_context_list: - text_sampling_func = triplet_text_sampling_with_context_list - else: - text_sampling_func = triplet_text_sampling - - logging.info(f"Text sampling: {text_sampling_func}") - - train_dl = libriheavy.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - text_sampling_func=text_sampling_func, - ) - - # For fair comparison, use fixed sampling in valid dataloaders - valid_cuts = libriheavy.dev_cuts() - valid_dl = libriheavy.valid_dataloaders( - valid_cuts, text_sampling_func=naive_triplet_text_sampling - ) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - tokenizer=tokenizer, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - tokenizer=tokenizer, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. 
- sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - tokenizer: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriHeavyAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py deleted file mode 100644 index ef0c48e8a..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py +++ /dev/null @@ -1,515 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -python ./zipformer_prompt_asr/transcribe_bert.py \ - --epoch 50 \ - --avg 10 \ - --exp-dir ./zipformer_prompt_asr/exp \ - --manifest-dir data/long_audios/long_audio.jsonl.gz \ - --pre-text-transform mixed-punc \ - --style-text-transform mixed-punc \ - --num-history 5 \ - --use-pre-text True \ - --use-gt-pre-text False - - -""" - -import argparse -import logging -import math -import warnings -from pathlib import Path -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from decode_bert import _apply_style_transform -from lhotse import Fbank, load_manifest -from text_normalization import ( - lower_all_char, - lower_only_alpha, - ref_text_normalization, - remove_non_alphabetic, - train_text_normalization, - upper_all_char, - upper_only_alpha, -) -from tqdm import tqdm -from train_bert_encoder import ( - _encode_texts_as_bytes_with_tokenizer, - add_model_arguments, - get_params, - get_tokenizer, - get_transducer_model, -) - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - ) - - parser.add_argument( - "--manifest-dir", - type=str, - default="data/long_audios/long_audio.jsonl.gz", - help="""This is the manfiest for long audio transcription. 
- The cust are intended to be sorted, i.e first sort by recording ID and - then sort by start timestamp""", - ) - - parser.add_argument( - "--use-pre-text", - type=str2bool, - default=False, - help="Whether use pre-text when decoding the current chunk", - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Use style prompt when evaluation", - ) - - parser.add_argument( - "--pre-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of content prompt, i.e pre_text", - ) - - parser.add_argument( - "--style-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], - default="mixed-punc", - help="The style of style prompt, i.e style_text", - ) - - parser.add_argument( - "--num-history", - type=int, - default=2, - help="How many previous chunks to look if using pre-text for decoding", - ) - - parser.add_argument( - "--use-gt-pre-text", - type=str2bool, - default=False, - help="Whether use gt pre text when using content prompt", - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=True, - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - params.res_dir = params.exp_dir / "long_audio_transcribe" - params.res_dir.mkdir(exist_ok=True) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "beam_search" in params.method: - params.suffix += f"-{params.method}-beam-size-{params.beam_size}" - - if params.use_pre_text: - if params.use_gt_pre_text: - params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}" - else: - params.suffix += ( - f"-pre-text-{params.pre_text_transform}-history-{params.num_history}" - ) - - book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "") - setup_logger( - f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info" - ) - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - tokenizer = get_tokenizer(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( 
- filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - # load manifest - manifest = load_manifest(params.manifest_dir) - - results = [] - count = 0 - - last_recording = "" - last_end = -1 - history = [] - num_pre_texts = [] - - for cut in manifest: - if cut.has_features: - feat = cut.load_features() - feat_lens = cut.num_frames - else: - feat = cut.compute_features(extractor=Fbank()) - feat_lens = feat.shape[0] - - cur_recording = cut.recording.id - - if cur_recording != last_recording: - last_recording = cur_recording - history = [] # clean up the history - last_end = -1 - logging.info("Moving on to the next recording") - else: - if cut.start < last_end - 0.2: # overlap with the previous cuts - logging.warning("An overlap exists between current cut and last cut") - logging.warning("Skipping this cut!") - continue - if cut.start > last_end + 10: - logging.warning( - f"Large time gap between the current and previous utterance: {cut.start - last_end}." - ) - - # prepare input - x = torch.tensor(feat, device=device).unsqueeze(0) - x_lens = torch.tensor( - [ - feat_lens, - ], - device=device, - ) - - if params.use_pre_text: - if params.num_history > 0: - pre_texts = history[-params.num_history :] - else: - pre_texts = [] - num_pre_texts.append(len(pre_texts)) - pre_texts = [train_text_normalization(" ".join(pre_texts))] - fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." 
- style_texts = [fixed_sentence] - - pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) - if params.use_style_prompt: - style_texts = _apply_style_transform( - style_texts, params.style_text_transform - ) - - # encode prompts - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( - pre_texts=pre_texts, - style_texts=style_texts, - tokenizer=tokenizer, - device=device, - no_limit=True, - ) - if params.num_history > 5: - logging.info( - f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} " - ) - - memory, memory_key_padding_mask = model.encode_text( - encoded_inputs=encoded_inputs, - style_lens=style_lens, - ) # (T,B,C) - else: - memory = None - memory_key_padding_mask = None - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=x, - feature_lens=x_lens, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - if params.method == "greedy_search": - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - hyp = sp.decode(hyp_tokens)[0] # in string format - ref_text = ref_text_normalization( - cut.supervisions[0].texts[0] - ) # required to match the training - - # extend the history - if params.use_gt_pre_text: - history.append(ref_text) - else: - history.append(hyp) - last_end = cut.end # update the last end timestamp - - # append the current decoding result - hyp = hyp.split() - ref = ref_text.split() - results.append((cut.id, ref, hyp)) - - count += 1 - if count % 100 == 0: - logging.info(f"Cuts processed until now: {count}/{len(manifest)}") - logging.info( - f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}" - ) - - logging.info(f"A total of {count} cuts") - logging.info( - f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}" - ) - - results = sorted(results) - recog_path = ( - params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - errs_filename = ( - params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"long-audio-{params.method}", - results, - enable_log=True, - compute_CER=False, - ) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - if params.post_normalization: - params.suffix += "-post-normalization" - - new_res = [] - for item in results: - id, ref, hyp = item - hyp = upper_only_alpha(" ".join(hyp)).split() - ref = upper_only_alpha(" ".join(ref)).split() - new_res.append((id, ref, hyp)) - - new_res = sorted(new_res) - recog_path = ( - params.res_dir - / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" - ) - store_transcripts(filename=recog_path, texts=new_res) - logging.info(f"The transcripts are stored in {recog_path}") - - errs_filename = ( - params.res_dir - / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"long-audio-{params.method}", - new_res, - 
enable_log=True, - compute_CER=False, - ) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - -if __name__ == "__main__": - main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py deleted file mode 100644 index 533982519..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py +++ /dev/null @@ -1,439 +0,0 @@ -import argparse -import ast -import glob -import logging -import os -from collections import defaultdict -from typing import Dict, Iterable, List, TextIO, Tuple, Union - -import kaldialign -from lhotse import load_manifest, load_manifest_lazy -from lhotse.cut import Cut, CutSet -from text_normalization import remove_non_alphabetic -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--manifest-dir", - type=str, - default="data/fbank", - help="Where are the manifest stored", - ) - - parser.add_argument( - "--subset", type=str, default="medium", help="Which subset to work with" - ) - - parser.add_argument( - "--top-k", - type=int, - default=10000, - help="How many words to keep", - ) - - return parser - - -def get_facebook_biasing_list( - test_set: str, - num_distractors: int = 100, -) -> Dict: - # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf - assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors - if num_distractors == 0: - if test_set == "test-clean": - biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv" - elif test_set == "test-other": - biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv" - else: - raise ValueError(f"Unseen test set {test_set}") - else: - if test_set == "test-clean": - biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv" - elif test_set == "test-other": - biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv" - else: - raise ValueError(f"Unseen test set {test_set}") - - f = open(biasing_file, "r") - data = f.readlines() - f.close() - - output = dict() - for line in data: - id, _, l1, l2 = line.split("\t") - if num_distractors > 0: # use distractors - biasing_list = ast.literal_eval(l2) - else: - biasing_list = ast.literal_eval(l1) - biasing_list = [w.strip().upper() for w in biasing_list] - output[id] = " ".join(biasing_list) - - return output - - -def brian_biasing_list(level: str): - # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf - root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level" - all_files = glob.glob(root_dir + "/*") - biasing_dict = {} - for f in all_files: - k = f.split("/")[-1] - fin = open(f, "r") - data = fin.read().strip().split() - biasing_dict[k] = " ".join(data) - fin.close() - - return biasing_dict - - -def get_rare_words( - subset: str = "medium", - top_k: int = 10000, - # min_count: int = 10000, -): - """Get a list of rare words appearing less than `min_count` times - - Args: - subset: The dataset - top_k (int): How many frequent words - """ - txt_path = f"data/tmp/transcript_words_{subset}.txt" - rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" - - if os.path.exists(rare_word_file): - print("File exists, do not proceed!") - return - - print("---Identifying rare words in the manifest---") - count_file = 
f"data/tmp/transcript_words_{subset}_count.txt" - if not os.path.exists(count_file): - with open(txt_path, "r") as file: - words = file.read().upper().split() - word_count = {} - for word in words: - word = remove_non_alphabetic(word, strict=False) - word = word.split() - for w in word: - if w not in word_count: - word_count[w] = 1 - else: - word_count[w] += 1 - - word_count = list(word_count.items()) # convert to a list of tuple - word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) - with open(count_file, "w") as fout: - for w, count in word_count: - fout.write(f"{w}\t{count}\n") - - else: - word_count = {} - with open(count_file, "r") as fin: - word_count = fin.read().strip().split("\n") - word_count = [pair.split("\t") for pair in word_count] - word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) - - print(f"A total of {len(word_count)} words appeared!") - rare_words = [] - for word, count in word_count[top_k:]: - rare_words.append(word + "\n") - print(f"A total of {len(rare_words)} are identified as rare words.") - - with open(rare_word_file, "w") as f: - f.writelines(rare_words) - - -def add_context_list_to_manifest( - manifest_dir: str, - subset: str = "medium", - top_k: int = 10000, -): - """Generate a context list of rare words for each utterance in the manifest - - Args: - manifest_dir: Where to store the manifest with context list - subset (str): Subset - top_k (int): How many frequent words - - """ - orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz" - target_manifest_dir = orig_manifest_dir.replace( - ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz" - ) - if os.path.exists(target_manifest_dir): - print(f"Target file exits at {target_manifest_dir}!") - return - - rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" - print(f"---Reading rare words from {rare_words_file}---") - with open(rare_words_file, "r") as f: - rare_words = f.read() - rare_words = rare_words.split("\n") - rare_words = set(rare_words) - print(f"A total of {len(rare_words)} rare words!") - - cuts = load_manifest_lazy(orig_manifest_dir) - print(f"Loaded manifest from {orig_manifest_dir}") - - def _add_context(c: Cut): - splits = ( - remove_non_alphabetic(c.supervisions[0].texts[0], strict=False) - .upper() - .split() - ) - found = [] - for w in splits: - if w in rare_words: - found.append(w) - c.supervisions[0].context_list = " ".join(found) - return c - - cuts = cuts.map(_add_context) - print(f"---Saving manifest with context list to {target_manifest_dir}---") - cuts.to_file(target_manifest_dir) - print("Finished") - - -def check( - manifest_dir: str, - subset: str = "medium", - top_k: int = 10000, -): - # Show how many samples in the training set have a context list - # and the average length of context list - print("--- Calculating the stats over the manifest ---") - - manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz" - cuts = load_manifest_lazy(manifest_dir) - total_cuts = len(cuts) - has_context_list = [c.supervisions[0].context_list != "" for c in cuts] - context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts] - print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! 
") - print( - f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}" - ) - - -def write_error_stats( - f: TextIO, - test_set_name: str, - results: List[Tuple[str, str]], - enable_log: bool = True, - compute_CER: bool = False, - biasing_words: List[str] = None, -) -> float: - """Write statistics based on predicted results and reference transcripts. It also calculates the - biasing word error rate as described in https://arxiv.org/pdf/2104.02194.pdf - - It will write the following to the given file: - - - WER - - number of insertions, deletions, substitutions, corrects and total - reference words. For example:: - - Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 - reference words (2337 correct) - - - The difference between the reference transcript and predicted result. - An instance is given below:: - - THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES - - The above example shows that the reference word is `EDISON`, - but it is predicted to `ADDISON` (a substitution error). - - Another example is:: - - FOR THE FIRST DAY (SIR->*) I THINK - - The reference word `SIR` is missing in the predicted - results (a deletion error). - results: - An iterable of tuples. The first element is the cut_id, the second is - the reference transcript and the third element is the predicted result. - enable_log: - If True, also print detailed WER to the console. - Otherwise, it is written only to the given file. - biasing_words: - All the words in the biasing list - Returns: - Return None. - """ - subs: Dict[Tuple[str, str], int] = defaultdict(int) - ins: Dict[str, int] = defaultdict(int) - dels: Dict[str, int] = defaultdict(int) - - # `words` stores counts per word, as follows: - # corr, ref_sub, hyp_sub, ins, dels - words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) - num_corr = 0 - ERR = "*" - - if compute_CER: - for i, res in enumerate(results): - cut_id, ref, hyp = res - ref = list("".join(ref)) - hyp = list("".join(hyp)) - results[i] = (cut_id, ref, hyp) - - for cut_id, ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR) - for ref_word, hyp_word in ali: - if ref_word == ERR: # INSERTION - ins[hyp_word] += 1 - words[hyp_word][3] += 1 - elif hyp_word == ERR: # DELETION - dels[ref_word] += 1 - words[ref_word][4] += 1 - elif hyp_word != ref_word: # SUBSTITUTION - subs[(ref_word, hyp_word)] += 1 - words[ref_word][1] += 1 - words[hyp_word][2] += 1 - else: - words[ref_word][0] += 1 - num_corr += 1 - ref_len = sum([len(r) for _, r, _ in results]) - sub_errs = sum(subs.values()) - ins_errs = sum(ins.values()) - del_errs = sum(dels.values()) - tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) - - if enable_log: - logging.info( - f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " - f"[{tot_errs} / {ref_len}, {ins_errs} ins, " - f"{del_errs} del, {sub_errs} sub ]" - ) - - print(f"%WER = {tot_err_rate}", file=f) - print( - f"Errors: {ins_errs} insertions, {del_errs} deletions, " - f"{sub_errs} substitutions, over {ref_len} reference " - f"words ({num_corr} correct)", - file=f, - ) - print( - "Search below for sections starting with PER-UTT DETAILS:, " - "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", - file=f, - ) - - print("", file=f) - print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) - for cut_id, ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR) - combine_successive_errors = True - if combine_successive_errors: - ali = [[[x], [y]] for x, y in ali] - for 
i in range(len(ali) - 1): - if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: - ali[i + 1][0] = ali[i][0] + ali[i + 1][0] - ali[i + 1][1] = ali[i][1] + ali[i + 1][1] - ali[i] = [[], []] - ali = [ - [ - list(filter(lambda a: a != ERR, x)), - list(filter(lambda a: a != ERR, y)), - ] - for x, y in ali - ] - ali = list(filter(lambda x: x != [[], []], ali)) - ali = [ - [ - ERR if x == [] else " ".join(x), - ERR if y == [] else " ".join(y), - ] - for x, y in ali - ] - - print( - f"{cut_id}:\t" - + " ".join( - ( - ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" - for ref_word, hyp_word in ali - ) - ), - file=f, - ) - - print("", file=f) - print("SUBSTITUTIONS: count ref -> hyp", file=f) - - for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): - print(f"{count} {ref} -> {hyp}", file=f) - - print("", file=f) - print("DELETIONS: count ref", file=f) - for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): - print(f"{count} {ref}", file=f) - - print("", file=f) - print("INSERTIONS: count hyp", file=f) - for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): - print(f"{count} {hyp}", file=f) - - unbiased_word_counts = 0 - unbiased_word_errs = 0 - biased_word_counts = 0 - biased_word_errs = 0 - - print("", file=f) - print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) - - for _, word, counts in sorted( - [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True - ): - (corr, ref_sub, hyp_sub, ins, dels) = counts - tot_errs = ref_sub + hyp_sub + ins + dels - # number of appearances of "word" in reference text - ref_count = ( - corr + ref_sub + dels - ) # correct + in ref but got substituted + deleted - # number of appearances of "word" in hyp text - hyp_count = corr + hyp_sub + ins - - if biasing_words is not None: - if word in biasing_words: - biased_word_counts += ref_count - biased_word_errs += ins + dels + ref_sub - else: - unbiased_word_counts += ref_count - unbiased_word_errs += ins + dels + hyp_sub - - print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - - if biasing_words is not None: - B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts) - U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts) - logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ") - logging.info( - f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]" - ) - - return float(tot_err_rate) - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - manifest_dir = args.manifest_dir - subset = args.subset - top_k = args.top_k - get_rare_words(subset=subset, top_k=top_k) - add_context_list_to_manifest( - manifest_dir=manifest_dir, - subset=subset, - top_k=top_k, - ) - check( - manifest_dir=manifest_dir, - subset=subset, - top_k=top_k, - ) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py deleted file mode 100644 index d1cf90ffb..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py +++ /dev/null @@ -1,2310 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. -) -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationDropoutAndLinear, - Balancer, - BiasNorm, - ChunkCausalDepthwiseConv1d, - Dropout2, - FloatLike, - ScheduledFloat, - Whiten, - convert_num_channels, - limit_param_value, - penalize_abs_values_gt, - softmax, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - value_head_dim (int or Tuple[int]): dimension of value in each attention head - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. 
Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - memory_dim: if supplied and >0, will be the dimension of the memory embeddings - passed into the zipformer (e.g. this might be the output of another - Zipformer used to create embedding vectors.) - memory_dropout_rate: By this probability, do not use the provided memory for - cross-attention. This should give robustness to the model when evaluated - without memory. - memory_layer: if supplied and >0, only add cross-attention module starting from - the specified layer. - """ - - def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - memory_dim: int = -1, - memory_dropout_rate: float = 0.05, - memory_layer: int = 0, - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( - encoder_unmasked_dim - ) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - self.memory_dropout_rate = memory_dropout_rate - self.memory_layer = memory_layer - - for u, d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - memory_dim=memory_dim if i >= self.memory_layer else -1, - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn 
something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - self.downsample_output = SimpleDownsample( - max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout - ) - - def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0, ( - self.encoder_dim[0], - _encoder_dims0, - ) - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = ( - torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob - ).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and( - mask1, - ( - torch.rand(1, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype), - ) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones( - 1, batch_size, channels, dtype=x.dtype, device=x.device - ) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. 
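A minimal sketch (names and sizes assumed, single-level mask only) of the per-sequence feature masking that get_feature_masks() above describes; the real code draws two nested masks at u1 and u2, but the principle is the same.

import torch

batch_size, channels, unmasked_dim = 4, 384, 256
feature_mask_dropout_prob = 0.125
# keep[b] == 0 means sequence b is restricted to its first `unmasked_dim` channels
keep = (torch.rand(1, batch_size, 1) > feature_mask_dropout_prob).float()
feature_mask = torch.ones(1, batch_size, channels)
feature_mask[:, :, unmasked_dim:] *= keep
# Multiplying the encoder activations by feature_mask zeroes the upper channels
# of the randomly chosen sequences, so on those frames the model effectively
# runs with a smaller encoder dimension.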
- """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - src_key_padding_mask: Optional[torch.Tensor] = None, - memory: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) - memory_key_padding_mask: optionally the mask for padding of memory input (for source- - attention), of shape (batch_size, memory_len); True means - masked position. May be None. - - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - outputs = [] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - if self.training and memory is not None: - batch_size = x.shape[1] - # setting memory to zero should be equivalent to not using the - # memory input at all, since the Attention module has no biases. - memory = memory * ( - torch.rand(batch_size, 1, device=memory.device) - > self.memory_dropout_rate - ) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module( - x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - attn_mask=attn_mask, - memory=memory if i >= self.memory_layer else None, - memory_key_padding_mask=memory_key_padding_mask - if i >= self.memory_layer - else None, - ) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. 
- assert self.output_downsampling_factor == 2, self.output_downsampling_factor - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, chunk_size: int, left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all( - chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders) - ) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [outputs[-1]] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. 
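A small worked example of the chunkwise causal mask built by _get_attn_mask() above, with seq_len, chunk_size and left_context_chunks chosen arbitrarily for illustration.

import torch

seq_len, chunk_size, left_context_chunks = 6, 2, 1
t = torch.arange(seq_len)
c = t // chunk_size                      # chunk index per frame: [0, 0, 1, 1, 2, 2]
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = torch.logical_or(
    src_c > tgt_c,                       # cannot look into future chunks
    src_c < tgt_c - left_context_chunks, # nor further back than the allowed left context
)
# attn_mask[i, j] is True where frame i must NOT attend to frame j; e.g. frames in
# chunk 2 may attend to chunks 1 and 2 but not to chunk 0.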
- - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) - cached_nonlin_attn = torch.zeros( - 1, batch_size, downsample_left, nonlin_attn_head_dim - ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. 
- - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - memory_dim: int = -1, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 - ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. 
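The skip-rate defaults just above are ScheduledFloat values; assuming the usual piecewise-linear-in-batch-count semantics of that class, the hypothetical helper below (not the real ScheduledFloat API) shows what a schedule such as (0.0, 0.2), (4000.0, 0.05), (16000, 0.0) evaluates to as training progresses.

def schedule(batch_idx, points):
    # Hypothetical helper, not the real ScheduledFloat class: linear
    # interpolation between (batch_count, value) breakpoints, clamped at the ends.
    xs, ys = zip(*points)
    if batch_idx <= xs[0]:
        return ys[0]
    if batch_idx >= xs[-1]:
        return ys[-1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_idx <= x1:
            return y0 + (y1 - y0) * (batch_idx - x0) / (x1 - x0)

pts = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)]
print(schedule(0, pts))       # 0.2   : heavy skipping early in training
print(schedule(2000, pts))    # 0.125 : halfway to the first breakpoint
print(schedule(20000, pts))   # 0.0   : no skipping once training has warmed up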
- self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) - - self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) - - if memory_dim > 0: - self.attn_weights = MultiheadAttentionWeights( - memory_dim, - embed_dim, - num_heads=num_heads, - head_dim=query_head_dim, - dropout=0.0, - ) - self.src_attn1 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) - self.src_attn2 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) - self.memory_balancer = Balancer( - embed_dim, - channel_dim=-1, - min_abs=0.015, - ) - - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout - ) - - self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 - ) - - self.conv_module1 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - - self.conv_module2 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - - # self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) - - self.norm = BiasNorm(embed_dim) - - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - - def get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. 
- if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, - min=float(self.bypass_min), - max=float(self.bypass_max), - ) - layer_skip_rate = float(self.layer_skip_rate) - if layer_skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - return ans - - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - memory: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - if memory is not None and hasattr(self, "attn_weights"): - src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask) - - src = src + self.feed_forward1(src) - - attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) - - if True: - selected_attn_weights = attn_weights[0:2] - if random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. 
- # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights[0:1])) - - src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask - ) - - if memory is not None and hasattr(self, "attn_weights"): - src = src + self.sequence_dropout( - self.memory_balancer(self.src_attn1(memory, src_attn_weights)), - attention_skip_rate, - ) - - src = src + self.sequence_dropout( - self.conv_module1( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - float(self.conv_skip_rate), - ) - - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), float(self.ff2_skip_rate) - ) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask - ) - - if memory is not None and hasattr(self, "attn_weights"): - src = src + self.sequence_dropout( - self.memory_balancer(self.src_attn2(memory, src_attn_weights)), - attention_skip_rate, - ) - - src = src + self.sequence_dropout( - self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - float(self.conv_skip_rate), - ) - - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), float(self.ff3_skip_rate) - ) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). 
- pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 - ) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - memory: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) - memory_key_padding_mask: optionally the mask for padding of memory input (for source- - attention), of shape (batch_size, memory_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - output = src - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - output = output * feature_mask - - return output - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. 
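A minimal sketch (shapes assumed) of the interpolation that BypassModule.forward() performs; per-sequence layer-skipping simply zeroes bypass_scale for randomly chosen batch elements during training.

import torch

seq_len, batch_size, embed_dim = 10, 2, 8
src_orig = torch.randn(seq_len, batch_size, embed_dim)  # input to the module
src = torch.randn(seq_len, batch_size, embed_dim)       # output of the module
bypass_scale = torch.full((embed_dim,), 0.5)            # learnable; clamped to [scale_min, scale_max] in training
out = src_orig + (src - src_orig) * bypass_scale
# scale 0.0 reproduces src_orig exactly (the module is bypassed); scale 1.0 keeps src.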
- """ - - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, - ): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) - ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - - return ans - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike - ): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout) - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0.025) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - memory: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. 
- memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) - memory_key_padding_mask: optionally the mask for padding of memory input (for source- - attention), of shape (batch_size, memory_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__(self, channels: int, downsample: int, dropout: FloatLike): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. 
- - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0 - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0 - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= x.size(0) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - T = x.size(0) - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. 
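A small numeric sketch (embed_dim and offsets chosen arbitrarily) of the compression used in extend_pe() above: large offsets such as 1000 and 1001 become almost indistinguishable after the log compression and atan(), while small offsets stay well separated, which is the stated design goal.

import math
import torch

embed_dim = 192
compression_length = embed_dim ** 0.5                    # ~13.9
x = torch.tensor([0.0, 1.0, 10.0, 100.0, 1000.0, 1001.0])
x_compressed = (
    compression_length
    * x.sign()
    * ((x.abs() + compression_length).log() - math.log(compression_length))
)
length_scale = 1.0 * embed_dim / (2.0 * math.pi)          # length_factor = 1.0
x_atan = (x_compressed / length_scale).atan()
# x_atan for offsets 1000 and 1001 differs only in the third decimal place,
# while 0, 1 and 10 map to clearly distinct values.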
- - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> Tensor: - """Create positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - - Returns: - positional embedding, of shape (1, 2*time-1, `*`). - - """ - self.extend_pe(x) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x.size(0) - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 - ) - - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be suffixient to fix the problem. 
- self.balance_keys = Balancer( - key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) - chunk_size - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - if not self.training or random.random() >= float(self.pos_emb_skip_rate): - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. 
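A sketch of what the as_strided() call just below computes, written as an explicit gather (reduced to a single head and batch element; the offset convention is assumed to put relative offset zero at index seq_len - 1, matching the comment's caveat about direction).

import torch

seq_len = 4
# rel_scores[i, r]: score of target frame i against relative offset r - (seq_len - 1)
rel_scores = torch.arange(2 * seq_len - 1).float().repeat(seq_len, 1)  # (time1, 2*seq_len-1)
i = torch.arange(seq_len).unsqueeze(1)   # target position
j = torch.arange(seq_len).unsqueeze(0)   # source position
abs_scores = rel_scores.gather(1, (j - i) + (seq_len - 1))
# abs_scores[i, j] == rel_scores[i, j - i + seq_len - 1]; the as_strided() call
# below performs the same re-indexing without copying, per head and per batch element.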
- pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - if self.training and random.random() < 0.1: - # This is away of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 25.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if random.random() < 0.001: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class Attention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
- - Args: - embed_dim_in: the input embedding dimension - embed_dim_out: the output embedding dimension (normally the same as input) - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim_in: int, - embed_dim_out: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim_in, num_heads * value_head_dim, bias=False) - - # Note we set bias to False so that input of 0 will have no effect - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim_out, bias=False, initial_scale=0.05 - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len), - Expect attn_weights.sum(dim=-1) == 1. The input here is the value in the - original attention mechanism. - Returns: - a tensor with the same shape as x. - """ - (num_heads, batch_size, query_len, key_len) = attn_weights.shape - - x = self.in_proj(x) # (key_len, batch_size, num_heads * value_head_dim) - x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, key_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, query_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(query_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (query_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] 
- - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - -class MultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head cross-attention weights. Allows src and target - to have different dims. - - Args: - key_embed_dim: number of channels of the thing that we'll project to - make the query (corresponds to source). e.g. 256 - query_embed_dim: number of channels of the thing that we'll project to - make the query (corresponds to target). e.g. 256 - num_heads: number of heads to compute weights for, e.g. 8 - head_dim: dimension of the query and key, per head. e.g. 24. - dropout: dropout probability for attn_output_weights. Default: 0.0. - """ - - def __init__( - self, - key_embed_dim: int, - query_embed_dim: int, - num_heads: int, - head_dim: int, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.key_embed_dim = key_embed_dim - self.query_embed_dim = query_embed_dim - self.num_heads = num_heads - self.head_dim = head_dim - self.dropout = dropout - self.name = None # will be overwritten in training code; for diagnostics. - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.query_in_proj = ScaledLinear( - query_embed_dim, - head_dim * num_heads, - bias=True, - initial_scale=head_dim**-0.25, - ) - - # weights produced by this module are invariant to adding a constant to - # the keys, so we don't need a bias for the keys. - self.key_in_proj = ScaledLinear( - key_embed_dim, - head_dim * num_heads, - bias=False, - initial_scale=head_dim**-0.25, - ) - - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - key: Tensor, - query: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - key: input of shape (key_len, batch_size, key_embed_dim) - query: input of shape (query_len, batch_size, query_embed_dim) - key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, query_len, key_len) - """ - q = self.query_in_proj(query) - k = self.key_in_proj(key) - - head_dim = self.head_dim - num_heads = self.num_heads - - query_len, batch_size, _ = q.shape - key_len, _batch_size, _ = k.shape - assert _batch_size == batch_size - - k = self.whiten_keys(k) # does nothing in the forward pass. - - q = q.reshape(query_len, batch_size, num_heads, head_dim) - k = k.reshape(key_len, batch_size, num_heads, head_dim) - - # time1 refers to target, time2 refers to source. 
- q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - if self.training and random.random() < 0.1: - # This is a way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 25.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, query_len, key_len) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - key_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if random.random() < 0.001: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwooshL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.1, - ) - - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. 
- x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. 
- Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == ( - num_heads, - batch_size, - seq_len, - left_context_len + seq_len, - ) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, ( - cached_x.shape[2], - left_context_len, - ) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in - # the range 1 to 4, but sometimes, for some reason, for layer 0 the - # rms ends up being very large, between 50 and 100 for different channels. - # This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. 
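        # balancer1 below implements that constraint on the gating half of
        # in_proj's output (the half that feeds the sigmoid in forward()),
        # with max_abs scheduled to rise from 5.0 to 10.0 over the first
        # 8000 batches.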
- self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) - - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwooshR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
- - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - memory_dim = 100 - - c = Zipformer2( - encoder_dim=(64, 96), - encoder_unmasked_dim=(48, 64), - num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,), - memory_dim=memory_dim, - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - memory=torch.randn(101, batch_size, memory_dim), - ) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) diff --git a/egs/librilight/SSL/zipformer/asr_datamodule.py b/egs/librilight/SSL/zipformer/asr_datamodule.py deleted file mode 120000 index b9313bffc..000000000 --- a/egs/librilight/SSL/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/beam_search.py b/egs/librilight/SSL/zipformer/beam_search.py deleted file mode 120000 index 3b02c21db..000000000 --- a/egs/librilight/SSL/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/dataset.py b/egs/librilight/SSL/zipformer/dataset.py deleted file mode 120000 index 5cd60d3b4..000000000 --- a/egs/librilight/SSL/zipformer/dataset.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/dataset.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/decode.py b/egs/librilight/SSL/zipformer/decode.py deleted file mode 100644 index 95643c5e1..000000000 --- a/egs/librilight/SSL/zipformer/decode.py +++ /dev/null @@ -1,1045 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao, -# Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from finetune import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
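      For example, with greedy search and a batch of two utterances, the
      returned dict could look like this (the hypotheses are purely
      illustrative):
          {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}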
- """ - device = next(model.parameters()).device - audio = batch["audio"].to(device) - padding_mask = batch["padding_mask"].to(device) - - encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(batch["supervisions"]["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["cuts"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" 
in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - 
f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append((sp.encode(line.strip()), 0.0)) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - dev_clean_cuts = librispeech.dev_clean_cuts() - dev_other_cuts = librispeech.dev_other_cuts() - - dev_clean_dl = librispeech.test_dataloaders( - dev_clean_cuts, - do_normalize=params.do_normalize, - ) - dev_other_dl = librispeech.test_dataloaders( - dev_other_cuts, - do_normalize=params.do_normalize, - ) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders( - test_clean_cuts, - do_normalize=params.do_normalize, - ) - test_other_dl = librispeech.test_dataloaders( - test_other_cuts, - do_normalize=params.do_normalize, - ) - - test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] - # test_sets = ["dev-clean", "dev-other"] - # test_dl = [dev_clean_dl, dev_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librilight/SSL/zipformer/decoder.py b/egs/librilight/SSL/zipformer/decoder.py deleted file mode 120000 index 96dbfc5cd..000000000 --- a/egs/librilight/SSL/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/decoder.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/encoder_interface.py b/egs/librilight/SSL/zipformer/encoder_interface.py deleted file mode 120000 index 30859c51b..000000000 --- a/egs/librilight/SSL/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py deleted file mode 100644 index 50dbd5f2d..000000000 --- a/egs/librilight/SSL/zipformer/finetune.py +++ /dev/null @@ -1,1552 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -# For HuBERT model finetuning: -./hubert/finetune.py \ - --world-size 8 \ - --num-epochs 200 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir hubert/exp \ - --full-libri 0 \ - --max-duration 1000 - -It supports finetuning with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder -from hubert_ce import HubertModel -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * params.accum_grad - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - # hubert parameters - parser.add_argument( - "--label-rate", - type=float, - default=50, - ) - - parser.add_argument( - "--sample-rate", - type=float, - default=16000, - ) - - parser.add_argument( - "--extractor-mode", - type=str, - default="default", - help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group - norm with d groups in the first conv block, whereas layer_norm - has layer norms in every block (meant to use with normalize=True)""", - ) - - parser.add_argument( - "--conv-feature-layers", - type=str, - default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", - help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", - ) - - parser.add_argument( - "--conv-bias", type=bool, default=False, help="include bias in conv encoder" - ) - - parser.add_argument( - "--feature-grad-mult", - type=float, - default=1.0, - help="multiply feature extractor var grads by this", - ) - - # masking - parser.add_argument("--mask-length", type=int, default=10, help="mask_length") - - parser.add_argument( - "--mask-prob", - type=float, - default=0.65, - help="probability of replacing a token with mask", - ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - default="static", - help="how to choose mask length", - ) - - parser.add_argument( - "--mask-other", - type=float, - default=0, - help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", - ) - - parser.add_argument( - "--no-mask-overlap", - type=bool, - default=False, - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-min-space", - type=int, - default=1, - help="min space between spans (if no overlap is enabled)", - ) - - # channel masking - parser.add_argument( - "--mask-channel-length", - type=int, - default=10, - help="length of the mask for features (channels)", - ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - default=0.0, - help="probability of replacing a feature with 0", - ) - - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - default="static", - help="how to choose mask length for channel masking", - ) - - parser.add_argument( - "--mask-channel-other", - type=float, - default=0, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", - ) - - parser.add_argument( - "--no-mask-channel-overlap", - type=bool, - default=False, - help="whether to allow channel masks to overlap", - ) - - parser.add_argument( - "--mask-channel-min-space", - type=int, - default=1, - help="min space between spans (if no overlap is enabled)", - ) - - # loss computation - parser.add_argument( - "--skip-masked", - type=bool, - default=False, - help="skip computing losses over masked frames", - ) - - parser.add_argument( - "--skip-nomask", - type=bool, - default=False, - help="skip computing losses over unmasked frames", - ) - - parser.add_argument( - "--checkpoint-activations", - type=bool, - default=False, - help="recompute activations and save memory for extra compute", - ) - - parser.add_argument( - "--pred-masked-weight", - type=float, - default=1, - help="weight for masked part in ssl loss", - ) - - parser.add_argument( - "--pred-nomask-weight", - type=float, - default=0, - help="weight for masked part in ssl loss", - ) - - parser.add_argument( - "--loss-weights", - type=float, - nargs="*", - default=[10], - help="weight for masked part in ssl loss", - ) - - # FP16 optimization - parser.add_argument( - "--required-seq-len-multiple", - type=int, - default=2, - help="pad the input to encoder such that the sequence length is divisible by multiple", - ) - - parser.add_argument( - 
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" - ) - - parser.add_argument( - "--pos-enc-type", - type=str, - default="abs", - help="Positional encoding type to use in conformer", - ) - - parser.add_argument( - "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" - ) - - parser.add_argument( - "--dropout-input", - type=float, - default=0.0, - help="dropout to apply to the input (after feat extr)", - ) - - parser.add_argument( - "--dropout-features", - type=float, - default=0.0, - help="dropout to apply to the features (after feat extr)", - ) - - parser.add_argument( - "--num-classes", - type=int, - nargs="*", - default=[504], - help="""num class, a little larger than the number of cluster, - the largest is for padding, - and the value should be the multiple of 4, for faster computation""", - ) - - parser.add_argument( - "--untie-final-proj", - type=bool, - default=False, - help="use separate projection for each target", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=222, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="hubert/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--pretrained-dir", - type=str, - help="""The pretrained model dir. - It specifies the directory where the pretrained checkpoint is saved.""", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.001, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=100000, - help="""Number of steps that affects how rapidly the learning rate - decreases. 
We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=100, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--sanity-check", - type=str2bool, - default=False, - help="Check if any of the batches in epoch 1 would cause OOM.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=100000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--accum-grad", - type=int, - default=1, - help="""update gradient when batch_idx_train % accum_grad == 0. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - - contains number of updates happen to the model so far across - epochs. - - - sub_batch_idx_train: It contains number of batch trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "sub_batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for pruned RNN-T loss - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - if hasattr(params, "pretrained_dir"): - logging.info(f"Loading {params.pretrained_dir}") - pretrained = torch.load(params.pretrained_dir) - encoder = HubertModel(params) - encoder.load_state_dict(pretrained["model"]) - else: - encoder = HubertModel(params) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - 
use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. 
It is an instance of Zipformer in our case. - batch: - A batch of data. See `dataset.HubertAsrDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - audio = batch["audio"].to(device) - padding_mask = batch["padding_mask"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, num_frames = model( - x=audio, - padding_mask=padding_mask, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = num_frames.sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. 
- - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for sub_batch_idx, batch in enumerate(train_dl): - params.sub_batch_idx_train += 1 - batch_idx = sub_batch_idx // params.accum_grad - - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss / params.accum_grad).backward() - - if sub_batch_idx % params.accum_grad == params.accum_grad - 1: - params.batch_idx_train += 1 - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
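            # Editor's note (added, not a line of the original file): the heuristic just
            # below doubles the grad scale whenever it has drifted under 8.0 (or under
            # 32.0, retried every 400 batches), warns and saves a "bad-model" snapshot
            # once it falls under 0.01, and aborts training below 1e-5, since a
            # collapsing grad scale usually indicates inf/NaN gradients.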
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if batch_idx % params.accum_grad != params.accum_grad - 1: - optimizer.zero_grad() - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=0) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = ( - librispeech.train_all_shuf_cuts() - if params.full_libri - else librispeech.train_clean_100_cuts() - ) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, - do_normalize=params.do_normalize, - sampler_state_dict=sampler_state_dict, - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - - valid_dl = librispeech.valid_dataloaders( - valid_cuts, - do_normalize=params.do_normalize, - ) - - if params.sanity_check and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `dataset.HubertAsrDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - audio = batch["audio"] - logging.info(f"audio shape: {audio.shape}") - - y = sp.encode(batch["supervisions"]["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librilight/SSL/zipformer/hubert_ce.py b/egs/librilight/SSL/zipformer/hubert_ce.py deleted file mode 120000 index 2b8482f78..000000000 --- a/egs/librilight/SSL/zipformer/hubert_ce.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/hubert_ce.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/joiner.py b/egs/librilight/SSL/zipformer/joiner.py deleted file mode 120000 index 587823e65..000000000 --- a/egs/librilight/SSL/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/joiner.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/model.py b/egs/librilight/SSL/zipformer/model.py deleted file mode 120000 index ca3daacca..000000000 --- a/egs/librilight/SSL/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/model.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/optim.py b/egs/librilight/SSL/zipformer/optim.py deleted file mode 120000 index bd2153ebf..000000000 --- a/egs/librilight/SSL/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/optim.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/pretrain.py b/egs/librilight/SSL/zipformer/pretrain.py deleted file mode 100644 index 5728dbe75..000000000 --- a/egs/librilight/SSL/zipformer/pretrain.py +++ /dev/null @@ -1,1366 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
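# --- Editor's sketch (added, not part of the original diff) --------------------
# Both training scripts in this patch expose --average-period, whose help text
# quotes the running-average update
#   model_avg = model * (average_period / batch_idx_train)
#             + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
# but the helper that applies it (update_averaged_model, imported from
# icefall.checkpoint) does not appear in this diff. The function below is a
# minimal, parameters-only illustration of that formula; the name
# blend_into_average is the editor's, and the real helper may additionally
# handle DDP wrapping, buffers, and other bookkeeping.

import torch
import torch.nn as nn


def blend_into_average(
    model: nn.Module,
    model_avg: nn.Module,
    batch_idx_train: int,
    average_period: int,
) -> None:
    # Weight given to the current model; the remainder stays with the average.
    w = average_period / batch_idx_train
    with torch.no_grad():
        for p_cur, p_avg in zip(model.parameters(), model_avg.parameters()):
            # model_avg is kept in float64 in these scripts, so cast before blending.
            p_avg.mul_(1.0 - w).add_(p_cur.to(p_avg.dtype), alpha=w)
# -------------------------------------------------------------------------------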
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -# For hubert model pretraining: -./zipformer/pretrain.py \ - --world-size 8 \ - --num-epochs 400 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 87.5 \ - --accum-grad 4 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from hubert_ce import HubertModel -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from ssl_datamodule import LibriLightDataModule -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
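    # Editor's note (added, not a line of the original file): with the usage
    # example above (--world-size 8 --max-duration 87.5 --accum-grad 4) and the
    # default --ref-duration of 600, 1000 optimizer steps correspond to
    # 1000 * 4 * (87.5 * 8) / 600 ≈ 4667 "reference" batches.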
- return ( - params.batch_idx_train - * params.accum_grad - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - # hubert parameters - parser.add_argument( - "--label-rate", - type=float, - default=50, - ) - - parser.add_argument( - "--sample-rate", - type=float, - default=16000, - ) - - parser.add_argument( - "--extractor-mode", - type=str, - default="default", - help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. 
default has a single group - norm with d groups in the first conv block, whereas layer_norm - has layer norms in every block (meant to use with normalize=True)""", - ) - - parser.add_argument( - "--conv-feature-layers", - type=str, - default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", - help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", - ) - - parser.add_argument( - "--conv-bias", type=bool, default=False, help="include bias in conv encoder" - ) - - parser.add_argument( - "--feature-grad-mult", - type=float, - default=1.0, - help="multiply feature extractor var grads by this", - ) - - # masking - parser.add_argument("--mask-length", type=int, default=10, help="mask_length") - - parser.add_argument( - "--mask-prob", - type=float, - default=0.65, - help="probability of replacing a token with mask", - ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - default="static", - help="how to choose mask length", - ) - - parser.add_argument( - "--mask-other", - type=float, - default=0, - help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh", - ) - - parser.add_argument( - "--no-mask-overlap", - type=bool, - default=False, - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-min-space", - type=int, - default=1, - help="min space between spans (if no overlap is enabled)", - ) - - # channel masking - parser.add_argument( - "--mask-channel-length", - type=int, - default=10, - help="length of the mask for features (channels)", - ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - default=0.0, - help="probability of replacing a feature with 0", - ) - - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - default="static", - help="how to choose mask length for channel masking", - ) - - parser.add_argument( - "--mask-channel-other", - type=float, - default=0, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh", - ) - - parser.add_argument( - "--no-mask-channel-overlap", - type=bool, - default=False, - help="whether to allow channel masks to overlap", - ) - - parser.add_argument( - "--mask-channel-min-space", - type=int, - default=1, - help="min space between spans (if no overlap is enabled)", - ) - - # loss computation - parser.add_argument( - "--skip-masked", - type=bool, - default=False, - help="skip computing losses over masked frames", - ) - - parser.add_argument( - "--skip-nomask", - type=bool, - default=False, - help="skip computing losses over unmasked frames", - ) - - parser.add_argument( - "--checkpoint-activations", - type=bool, - default=False, - help="recompute activations and save memory for extra compute", - ) - - parser.add_argument( - "--pred-masked-weight", - type=float, - default=1, - help="weight for masked part in ssl loss", - ) - - parser.add_argument( - "--pred-nomask-weight", - type=float, - default=0, - help="weight for masked part in ssl loss", - ) - - parser.add_argument( - "--loss-weights", - type=float, - nargs="*", - default=[10], - help="weight for masked part in ssl loss", - ) - - # FP16 optimization - parser.add_argument( - "--required-seq-len-multiple", - type=int, - default=2, - help="pad the input to encoder such that the sequence length is divisible by multiple", - ) - - parser.add_argument( - 
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA" - ) - - parser.add_argument( - "--pos-enc-type", - type=str, - default="abs", - help="Positional encoding type to use in conformer", - ) - - parser.add_argument( - "--logit-temp", type=float, default=0.1, help="temperature to divide logits by" - ) - - parser.add_argument( - "--dropout-input", - type=float, - default=0.0, - help="dropout to apply to the input (after feat extr)", - ) - - parser.add_argument( - "--dropout-features", - type=float, - default=0.0, - help="dropout to apply to the features (after feat extr)", - ) - - parser.add_argument( - "--num-classes", - type=int, - nargs="*", - default=[504], - help="""num class, a little larger than the number of cluster, - the largest is for padding, - and the value should be the multiple of 4, for faster computation""", - ) - - parser.add_argument( - "--untie-final-proj", - type=bool, - default=False, - help="use separate projection for each target", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=400, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. 
- """, - ) - - parser.add_argument( - "--warmup-batches", - type=float, - default=5000, - help="Eden warmup steps", - ) - - parser.add_argument( - "--warmup-start", - type=float, - default=0, - help="Eden warmup start learning rate", - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--sanity-check", - type=str2bool, - default=False, - help="Check if any of the batches in epoch 1 would cause OOM.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=100000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--accum-grad", - type=int, - default=4, - help="""update gradient when batch_idx_train % accum_grad == 0. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--max-sample-size", - type=float, - default=250000, - help="max sample size", - ) - - parser.add_argument( - "--min-sample-size", - type=float, - default=32000, - help="min sample size", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. 
- - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of updates happen to the model so far across - epochs. - - - sub_batch_idx_train: It contains number of batch trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "sub_batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_model(params: AttributeDict) -> nn.Module: - model = HubertModel(params) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. 
- scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `dataset.HubertDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - audio = batch["audio"].to(device) - padding_mask = batch["padding_mask"].to(device) - kmeans = batch["kmeans"].to(device) - - with torch.set_grad_enabled(is_training): - loss, num_masked_tokens, logging_output = model( - source=audio, target_list=[kmeans], padding_mask=padding_mask - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = num_masked_tokens - for item in logging_output: - info[item] = logging_output[item] - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. 
- valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for sub_batch_idx, batch in enumerate(train_dl): - params.sub_batch_idx_train += 1 - batch_idx = sub_batch_idx // params.accum_grad - - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - batch_size = batch["kmeans"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss / params.accum_grad).backward() - - if sub_batch_idx % params.accum_grad == params.accum_grad - 1: - params.batch_idx_train += 1 - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if batch_idx % params.accum_grad != params.accum_grad - 1: - optimizer.zero_grad() - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden( - optimizer, - params.lr_batches, - params.lr_epochs, - params.warmup_batches, - params.warmup_start, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - librilight = LibriLightDataModule(args) - - train_cuts = librilight.train_all_shuf_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if ( - c.duration < params.min_sample_size / params.sample_rate - or c.duration > params.max_sample_size / params.sample_rate - ): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librilight.train_dataloaders( - train_cuts, - sample_rate=params.sample_rate, - label_rate=params.label_rate, - random_crop=params.random_crop, - pad_audio=False, - num_classes=params.num_classes, - do_normalize=params.do_normalize, - sampler_state_dict=sampler_state_dict, - ) - - valid_cuts = librilight.dev_clean_cuts() - # valid_cuts += librilight.dev_other_cuts() - valid_cuts = valid_cuts.filter(remove_short_and_long_utt) - - valid_dl = librilight.valid_dataloaders( - valid_cuts, - sample_rate=params.sample_rate, - label_rate=params.label_rate, - random_crop=params.random_crop, - pad_audio=False, - num_classes=params.num_classes, - do_normalize=params.do_normalize, - ) - - if params.sanity_check and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `dataset.HubertDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - audio = batch["audio"] - logging.info(f"audio shape: {audio.shape}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriLightDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librilight/SSL/zipformer/scaling.py b/egs/librilight/SSL/zipformer/scaling.py deleted file mode 120000 index 24b661dfb..000000000 --- a/egs/librilight/SSL/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/scaling.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/ssl_datamodule.py b/egs/librilight/SSL/zipformer/ssl_datamodule.py deleted file mode 100644 index dc0dbec6c..000000000 --- a/egs/librilight/SSL/zipformer/ssl_datamodule.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import glob -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from dataset import HubertDataset -from lhotse import CutSet, combine, load_manifest_lazy -from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriLightDataModule: - """ - DataModule for SSL experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). 
- - It contains all the common data pipeline modules used in SSL - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - This class should be derived for specific corpora used in SSL tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR SSL related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/kmeans"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=float, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - group.add_argument( - "--do-normalize", - type=str2bool, - default=True, - help="whether to normalize the data", - ) - group.add_argument( - "--random-crop", - type=str2bool, - default=True, - help="audio sample rate", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sample_rate: float = 16000, - label_rate: float = 50, - random_crop: bool = True, - pad_audio: bool = False, - num_classes: list = [504], - do_normalize: bool = True, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = HubertDataset( - sample_rate=sample_rate, - label_rate=label_rate, - random_crop=random_crop, - pad_audio=pad_audio, - num_classes=num_classes, - do_normalize=do_normalize, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders( - self, - cuts_valid: CutSet, - sample_rate: float = 16000, - label_rate: float = 50, - random_crop: bool = True, - pad_audio: bool = False, - num_classes: list = [504], - do_normalize: bool = True, - ) -> DataLoader: - logging.info("About to create dev dataset") - validate = HubertDataset( - sample_rate=sample_rate, - label_rate=label_rate, - random_crop=random_crop, - pad_audio=pad_audio, - num_classes=num_classes, - do_normalize=do_normalize, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders( - self, - cuts: CutSet, - sample_rate: float = 16000, - label_rate: float = 50, - random_crop: bool = True, - pad_audio: bool = False, - num_classes: list = [504], - do_normalize: bool = True, - ) -> DataLoader: - logging.debug("About to create test dataset") - test = HubertDataset( - sample_rate=sample_rate, - label_rate=label_rate, - random_crop=random_crop, - pad_audio=pad_audio, - num_classes=num_classes, - do_normalize=do_normalize, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def small_cuts(self) -> CutSet: - logging.info("About to get small cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librilight_cuts_small.jsonl.gz" - ) - - @lru_cache() - def medium_cuts(self) -> CutSet: - logging.info("About to get medium cuts") - filenames = glob.glob( - f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz" - ) - pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode" - ) - - return combine(load_manifest_lazy(p) for p in sorted_filenames) - - @lru_cache() - def large_cuts(self) -> CutSet: - logging.info("About to get large cuts") - filenames = glob.glob( - f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz" - ) - pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode" - ) - - return combine(load_manifest_lazy(p) for p in sorted_filenames) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info("About to get the shuffled small, medium and large cuts") - small_cuts = self.small_cuts() - medium_cuts = self.medium_cuts() - large_cuts = self.large_cuts() - return CutSet.mux( - small_cuts, - medium_cuts, - 
large_cuts, - weights=[ - 122867, # len(small_cuts) - 1104071, # len(medium_cuts) - 11012085, # len(large_cuts) - ], - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) diff --git a/egs/librilight/SSL/zipformer/utils.py b/egs/librilight/SSL/zipformer/utils.py deleted file mode 120000 index 119992bdb..000000000 --- a/egs/librilight/SSL/zipformer/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/utils.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/wav2vec2_module.py b/egs/librilight/SSL/zipformer/wav2vec2_module.py deleted file mode 120000 index 81ad701e4..000000000 --- a/egs/librilight/SSL/zipformer/wav2vec2_module.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/wav2vec2_module.py \ No newline at end of file diff --git a/egs/librilight/SSL/zipformer/zipformer.py b/egs/librilight/SSL/zipformer/zipformer.py deleted file mode 120000 index 5b3da8cd5..000000000 --- a/egs/librilight/SSL/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/SSL/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/.vscode/launch.json b/egs/librispeech/ASR/.vscode/launch.json new file mode 100755 index 000000000..7888a8de6 --- /dev/null +++ b/egs/librispeech/ASR/.vscode/launch.json @@ -0,0 +1,246 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug Training (Quick Test)", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/train.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + "--world-size", "1", + "--num-epochs", "1", + "--att-rate", "0.0", + "--max-duration", "20", + "--start-epoch", "1", + "--num-epochs", "1", + "--valid-interval", "1", + "--validation-decoding-method", "greedy" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "0", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Training (Full Config)", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/train.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--sanity-check", "false", + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + "--world-size", "1", + "--num-epochs", "50", + "--start-epoch", "0", + "--att-rate", "0.0", + "--num-buckets", "30", + "--seed", "42", + "--valid-interval", "10", + "--validation-decoding-method", "greedy", + "--validation-search-beam", "20", + "--validation-output-beam", "8" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "3", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Validation Only", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/decode.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + 
"--max-duration", "100", + "--method", "ctc-decoding", + "--epoch", "1", + "--avg", "1" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "0", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Data Loading", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/asr_datamodule.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [], + "env": { + "CUDA_VISIBLE_DEVICES": "0", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Training with Augmentation", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/train.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + "--world-size", "1", + "--num-epochs", "5", + "--att-rate", "0.0", + "--max-duration", "100", + "--enable-spec-aug", "True", + "--enable-musan", "True", + "--enable-rir", "True", + "--enable-cutmix", "True", + "--enable-concatenate", "True", + "--spec-aug-time-warp-factor", "80", + "--spec-aug-num-frame-masks", "2", + "--spec-aug-frame-mask-max-length", "10", + "--musan-cuts-path", "./data/musan_cuts.jsonl.gz", + "--rir-cuts-path", "./data/rir_cuts.jsonl.gz", + "--valid-interval", "1", + "--validation-decoding-method", "greedy" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "0", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Phone-based Training", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/train.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--exp-dir", "./conformer_ctc/exp_phone", + "--lang-dir", "./data/lang_phone", + "--world-size", "1", + "--num-epochs", "5", + "--att-rate", "0.0", + "--max-duration", "100", + "--valid-interval", "1", + "--validation-decoding-method", "greedy" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "0", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug CTC Decoding (from decode.sh)", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/decode.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--method", "ctc-decoding", + "--max-duration", "20", + "--epoch", "12", + "--avg", "3", + "--exp-dir", "./conformer_ctc/exp/models", + "--lang-dir", "./data/lang_bpe_5000" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "2", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug CTC Decoding (Quick Test)", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/decode.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--method", "ctc-decoding", + "--max-duration", "5", + "--epoch", "12", + "--avg", "1", + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "2", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Attention Decoder", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/decode.py", + "console": "integratedTerminal", + 
"justMyCode": false, + "args": [ + "--method", "attention-decoder", + "--max-duration", "10", + "--epoch", "12", + "--avg", "3", + "--num-paths", "100", + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + "--lm-dir", "./data/lm" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "2", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + }, + { + "name": "Debug Whole Lattice Rescoring", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/conformer_ctc/decode.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": [ + "--method", "whole-lattice-rescoring", + "--max-duration", "10", + "--epoch", "12", + "--avg", "3", + "--exp-dir", "./conformer_ctc/exp", + "--lang-dir", "./data/lang_bpe_5000", + "--lm-dir", "./data/lm" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "2", + "PYTHONFAULTHANDLER": "1", + "PYTHONPATH": "/home/hdd2/jenny/ASRToolkit/icefall" + }, + "stopOnEntry": false + } + ] + } \ No newline at end of file diff --git a/egs/librispeech/ASR/.vscode/settings.json b/egs/librispeech/ASR/.vscode/settings.json new file mode 100755 index 000000000..d3d0699e5 --- /dev/null +++ b/egs/librispeech/ASR/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "workbench.colorCustomizations": { + "activityBar.activeBackground": "#e6b7c3", + "activityBar.background": "#e6b7c3", + "activityBar.foreground": "#15202b", + "activityBar.inactiveForeground": "#15202b99", + "activityBarBadge.background": "#498e31", + "activityBarBadge.foreground": "#e7e7e7", + "commandCenter.border": "#15202b99", + "sash.hoverBorder": "#e6b7c3", + "statusBar.background": "#d991a3", + "statusBar.foreground": "#15202b", + "statusBarItem.hoverBackground": "#cc6b83", + "statusBarItem.remoteBackground": "#d991a3", + "statusBarItem.remoteForeground": "#15202b", + "titleBar.activeBackground": "#d991a3", + "titleBar.activeForeground": "#15202b", + "titleBar.inactiveBackground": "#d991a399", + "titleBar.inactiveForeground": "#15202b99" + }, + "peacock.remoteColor": "#d991a3" +} \ No newline at end of file diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/RESULTS-100hours.md b/egs/librispeech/ASR/RESULTS-100hours.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc/Testing/Temporary/LastTest.log b/egs/librispeech/ASR/conformer_ctc/Testing/Temporary/LastTest.log new file mode 100644 index 000000000..0ae56f3e5 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/Testing/Temporary/LastTest.log @@ -0,0 +1,3 @@ +Start testing: Aug 21 16:57 KST +---------------------------------------------------------- +End testing: Aug 21 16:57 KST diff --git a/egs/librispeech/ASR/conformer_ctc/__init__.py b/egs/librispeech/ASR/conformer_ctc/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py old mode 100644 new mode 100755 index ea793ce2f..00d438730 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -49,12 +49,12 @@ class Conformer(Transformer): d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, - num_encoder_layers: int = 
12, + num_encoder_layers: int = 16, num_decoder_layers: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, - vgg_frontend: bool = False, + vgg_frontend: bool = True, use_feat_batchnorm: Union[float, bool] = 0.1, ) -> None: super(Conformer, self).__init__( diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 7e0bf5b7b..a9020db2a 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -42,7 +42,6 @@ from icefall.decode import ( rescore_with_whole_lattice, ) from icefall.env import get_env_info -from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, @@ -130,21 +129,21 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="conformer_ctc/exp", + default="conformer_ctc/exp/models", help="The experiment dir", ) parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe_500", - help="The lang dir", + default="data/lang_bpe_5000", + help="The lang dir (using BPE)", ) parser.add_argument( "--lm-dir", type=str, - default="data/lm", + default="/home/hdd1/jenny/lm", help="""The n-gram LM dir. It should contain either G_4_gram.pt or G_4_gram.fst.txt """, @@ -217,14 +216,14 @@ def get_params() -> AttributeDict: "vgg_frontend": False, "use_feat_batchnorm": True, "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, + "nhead": 4, + "attention_dim": 256, + "num_decoder_layers": 0, # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, - "max_active_states": 10000, + "max_active_states": 1000, "use_double_scores": True, "env_info": get_env_info(), } @@ -294,6 +293,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ + if HLG is not None: device = HLG.device else: @@ -304,10 +304,12 @@ def decode_one_batch( # at entry, feature is (N, T, C) supervisions = batch["supervisions"] - + + # Step 1: Model forward pass nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) # nnet_output is (N, T, C) - + + # Step 2: Supervision segments preparation supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -317,6 +319,14 @@ def decode_one_batch( 1, ).to(torch.int32) + + # Ensure supervision segments don't exceed nnet_output length + max_allowed_frames = nnet_output.size(1) + supervision_segments[:, 2] = torch.clamp(supervision_segments[:, 2], max=max_allowed_frames) + + # CRITICAL FIX: k2.DenseFsaVec requires supervision_segments to be on CPU + supervision_segments = supervision_segments.cpu() + if H is None: assert HLG is not None decoding_graph = HLG @@ -324,7 +334,8 @@ def decode_one_batch( assert HLG is None assert bpe_model is not None decoding_graph = H - + + # Step 3: Lattice generation lattice = get_lattice( nnet_output=nnet_output, decoding_graph=decoding_graph, @@ -337,9 +348,11 @@ def decode_one_batch( ) if params.method == "ctc-decoding": + # Step 4: CTC decoding best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs # since we are using H, not HLG here. # @@ -351,6 +364,7 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] hyps = [s.split() for s in hyps] + key = "ctc-decoding" return {key: hyps} @@ -523,9 +537,17 @@ def decode_dataset( num_batches = "?" results = defaultdict(list) + + logging.info(f"Starting decode with {num_batches} batches") + for batch_idx, batch in enumerate(dl): + + logging.info(f"Processing batch {batch_idx}/{num_batches}") + texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + logging.info(f"Batch {batch_idx}: {len(texts)} cuts, cut_ids: {cut_ids[:3]}...") hyps_dict = decode_one_batch( params=params, @@ -536,11 +558,11 @@ def decode_dataset( bpe_model=bpe_model, batch=batch, word_table=word_table, - G=G, sos_id=sos_id, eos_id=eos_id, + G=G, ) - + if hyps_dict is not None: for lm_scale, hyps in hyps_dict.items(): this_batch = [] @@ -550,6 +572,35 @@ def decode_dataset( this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) + + # Log ground truth vs predicted examples for the first method only + if lm_scale == list(hyps_dict.keys())[0]: # Only log for the first decoding method + # Log a few examples from this batch + num_examples = min(3, len(texts)) # Show up to 3 examples per batch + if num_examples > 0: + logging.info(f"=== DECODE EXAMPLES - Batch {batch_idx} ===") + for i in range(num_examples): + cut_id = cut_ids[i] + ref_text = texts[i] + hyp_text = " ".join(hyps[i]) + + logging.info(f"Example {i+1} (ID: {cut_id}):") + logging.info(f" REF: {ref_text}") + logging.info(f" HYP: {hyp_text}") + + # Simple accuracy check + ref_words = ref_text.split() + hyp_words = hyps[i] + if ref_words == hyp_words: + logging.info(f" --> ✅ PERFECT MATCH ({len(ref_words)} words)") + else: + # Calculate simple word error rate for this utterance + import difflib + matcher = difflib.SequenceMatcher(None, ref_words, hyp_words) + word_errors = len(ref_words) + len(hyp_words) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_words) * 100) if len(ref_words) > 0 else 0 + logging.info(f" --> ❌ WER: {utt_wer:.1f}% (REF: {len(ref_words)} words, HYP: {len(hyp_words)} words)") + logging.info("=" * 50) else: assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] @@ -563,10 +614,12 @@ def decode_dataset( num_cuts += len(texts) - if batch_idx % 100 == 0: + if batch_idx % 10 == 0: # Log more frequently for validation batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"[VALIDATION] batch {batch_str}, cuts processed: {num_cuts}, " + f"cuts in this batch: {len(texts)}") + + logging.info(f"Completed decode_dataset with {num_cuts} total cuts processed") return results @@ -580,17 +633,23 @@ def save_results( enable_log = False else: enable_log = True + + # Create results directory if it doesn't exist + results_dir = params.exp_dir / "results" + results_dir.mkdir(exist_ok=True) + test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + # Save transcripts in results folder + recog_path = results_dir / f"recogs-{test_set_name}-{key}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + # ref/hyp pairs - also save in results folder + errs_filename = results_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=enable_log @@ -601,7 +660,8 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + # Save WER summary in results folder + errs_info = results_dir / f"wer-summary-{test_set_name}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -613,6 +673,9 @@ def save_results( s += "{}\t{}{}\n".format(key, val, note) note = "" logging.info(s) + + # Return WER results for external use + return dict(test_set_wers) @torch.no_grad() @@ -631,9 +694,11 @@ def main(): logging.info("Decoding started") logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # For BPE mode: read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + num_classes = len(f.readlines()) + max_token_id = num_classes - 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -654,6 +719,16 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id + # Create BPE word table from tokens.txt + word_table = {} + with open(tokens_file, 'r', encoding='utf-8') as f: + for line in f: + if line.strip(): + parts = line.strip().split() + if len(parts) >= 2: + token, idx = parts[0], parts[1] + word_table[int(idx)] = token + if params.method == "ctc-decoding": HLG = None H = k2.ctc_topo( @@ -684,7 +759,8 @@ def main(): logging.info("Loading G_4_gram.fst.txt") logging.warning("It may take 8 minutes.") with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] + # For BPE mode: use a default disambig ID (assuming #0 maps to ID 0) + first_word_disambig_id = 0 # This should be adjusted based on your BPE vocab G = k2.Fsa.from_openfst(f.read(), acceptor=False) # G.aux_labels is not needed in later computations, so @@ -779,16 +855,10 @@ def main(): args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + # Get all test dataloaders (LibriSpeech + CHiME-4) + all_test_dls = librispeech.all_test_dataloaders() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): + for test_set_name, test_dl in all_test_dls.items(): results_dict = decode_dataset( dl=test_dl, params=params, @@ -797,13 +867,13 @@ def main(): HLG=HLG, H=H, bpe_model=bpe_model, - word_table=lexicon.word_table, + word_table=word_table, G=G, sos_id=sos_id, eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results(params=params, test_set_name=test_set_name, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py old mode 100644 new mode 100755 diff --git 
a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 828106f41..f173a8042 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -38,6 +38,7 @@ import k2 import torch import torch.multiprocessing as mp import torch.nn as nn +import sentencepiece as spm from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.cut import Cut @@ -47,6 +48,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam +from decode import decode_dataset, save_results from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint @@ -55,14 +57,19 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, + load_averaged_model, MetricsTracker, encode_supervisions, setup_logger, str2bool, ) +# Global counter for validation samples to control terminal logging frequency +_VALIDATION_SAMPLE_COUNTER = 0 + def get_parser(): parser = argparse.ArgumentParser( @@ -93,7 +100,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=78, + default=100, help="Number of epochs to train.", ) @@ -110,7 +117,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="conformer_ctc/exp", + default="./conformer_ctc/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -120,7 +127,17 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe_500", + default="./data/lang_phone", help="""The lang dir It contains language related input files such as "lexicon.txt" """, ) + + parser.add_argument( + "--bpe-dir", + type=str, + default="./data/lang_bpe_5000", help="""The lang dir It contains language related input files such as "lexicon.txt" @@ -139,7 +156,7 @@ def get_parser(): parser.add_argument( "--num-decoder-layers", type=int, - default=6, + default=0, help="""Number of decoder layer of transformer decoder. Setting this to 0 will not create the decoder at all (pure CTC model) """, @@ -152,13 +169,84 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--warm-step", + type=int, + default=30000, + help="Number of warmup steps for Noam optimizer. " + "Recommended: 30000 (with data aug), 15000-20000 (without data aug)", + ) + parser.add_argument( "--seed", type=int, default=42, help="The seed for random generators intended for reproducibility", ) - + parser.add_argument( + "--sanity-check", + type=str2bool, + default=True, + help="Whether to run the OOM sanity check (scan_pessimistic_batches_for_oom) before training starts.", + ) + + parser.add_argument( + "--method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - ctc-decoding: CTC greedy search or beam search. + - nbest-rescoring: Use N-best list for LM rescoring. 
+ - whole-lattice-rescoring: Use whole lattice for LM rescoring. + - attention-decoder: Use attention decoder rescoring. + - rnn-lm: Use RNN LM for rescoring. + """, + ) + + parser.add_argument( + "--enable-validation", + type=str2bool, + default=True, + help="Enable validation during training. Set to False to disable validation completely.", + ) + + parser.add_argument( + "--valid-interval", + type=int, + default=3000, + help="Run validation every N batches. Increase this to validate less frequently.", + ) + + parser.add_argument( + "--validation-decoding-method", + type=str, + default="greedy", + choices=["greedy", "beam"], + help="Decoding method for validation: 'greedy' for faster validation, 'beam' for more accurate WER.", + ) + + parser.add_argument( + "--validation-search-beam", + type=float, + default=10.0, + help="Search beam size for validation decoding (only used with beam search).", + ) + + parser.add_argument( + "--validation-output-beam", + type=float, + default=5.0, + help="Output beam size for validation decoding (only used with beam search).", + ) + + parser.add_argument( + "--validation-skip-wer", + type=str2bool, + default=False, + help="Skip WER computation during validation for faster validation (only compute loss).", + ) + return parser @@ -232,20 +320,25 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, + "valid_interval": 3000, # Default value, will be overridden by args # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, + "attention_dim": 256, + "nhead": 4, # parameters for loss "beam_size": 10, "reduction": "sum", "use_double_scores": True, + # parameters for decoding/validation + "search_beam": 20.0, + "output_beam": 8.0, + "min_active_states": 30, + "max_active_states": 10000, # parameters for Noam "weight_decay": 1e-6, - "warm_step": 80000, + "warm_step": 30000, "env_info": get_env_info(), } ) @@ -283,7 +376,18 @@ def load_checkpoint_if_available( if params.start_epoch <= 0: return - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + # First try to find checkpoint in models directory + models_dir = params.exp_dir / "models" + filename = models_dir / f"epoch-{params.start_epoch-1}.pt" + + # If not found in models directory, try the old location for backward compatibility + if not filename.exists(): + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + + if not filename.exists(): + logging.warning(f"Checkpoint not found at {filename}") + return + saved_params = load_checkpoint( filename, model=model, @@ -310,6 +414,9 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, rank: int = 0, + suffix: str = "", + wer_value: Optional[float] = None, + step: Optional[int] = None, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -318,10 +425,27 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + wer_value: + WER value to include in filename (optional). + step: + Training step to include in filename instead of epoch (optional). 
""" if rank != 0: return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + + # Create models directory if it doesn't exist + models_dir = params.exp_dir / "models" + models_dir.mkdir(exist_ok=True) + + if suffix: + # Use step instead of epoch for validation checkpoints + epoch_or_step = step if step is not None else params.cur_epoch + if wer_value is not None: + filename = models_dir / f"step-{epoch_or_step}-{suffix}-wer{wer_value:.2f}.pt" + else: + filename = models_dir / f"step-{epoch_or_step}-{suffix}.pt" + else: + filename = models_dir / f"epoch-{params.cur_epoch}.pt" save_checkpoint_impl( filename=filename, model=model, @@ -332,12 +456,16 @@ def save_checkpoint( ) if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" + best_train_filename = models_dir / "best-train-loss.pt" copyfile(src=filename, dst=best_train_filename) if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" + best_valid_filename = models_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) + + logging.info(f"Checkpoint saved successfully to {filename}") + # Remove the print statement that might be causing issues + # print("Saving All Done!") def compute_loss( @@ -398,9 +526,17 @@ def compute_loss( dense_fsa_vec = k2.DenseFsaVec( nnet_output, supervision_segments, - allow_truncate=params.subsampling_factor - 1, + allow_truncate=max(params.subsampling_factor - 1, 10), + # allow_truncate=0 ) - + # print("nnet_output shape: ", nnet_output.shape) + # print("supervisions: ", supervisions) + # print("supervision_segments: ", supervision_segments) + # print("graph_compiler: ", graph_compiler) + # Remove assertion that causes issues with subsampling + # assert supervision_segments[:, 2].max() <= nnet_output.size(1), \ + # "supervision_segments length exceeds nnet_output length" + ctc_loss = k2.ctc_loss( decoding_graph=decoding_graph, dense_fsa_vec=dense_fsa_vec, @@ -435,12 +571,11 @@ def compute_loss( assert loss.requires_grad == is_training + info = MetricsTracker() info["frames"] = supervision_segments[:, 2].sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - + info["att_loss"] = att_loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa @@ -461,32 +596,194 @@ def compute_validation_loss( graph_compiler: BpeCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, + epoch: int = 1, + quick_validation: bool = True, # Add option for quick validation + rank: int = 0, # Add rank parameter + tb_writer: Optional[SummaryWriter] = None, # Add TensorBoard writer parameter ) -> MetricsTracker: - """Run the validation process.""" + + model.eval() + + with torch.no_grad(): + device = next(model.parameters()).device + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info - tot_loss = MetricsTracker() + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = 
compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info + logging.info("Validation loss computation completed") - if world_size > 1: - tot_loss.reduce(loss.device) + # Always compute WER for analysis + logging.info("Starting WER computation...") + + # Use the existing graph_compiler instead of creating a new one + # to ensure device compatibility in DDP training + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + # Read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + vocab_size = len(f.readlines()) + max_token_id = vocab_size - 1 - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value + # WER calculation with proper device handling + if params.att_rate == 0.0: + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False - return tot_loss + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + # For BPE mode, create a simple word table from tokens + if "lang_bpe" in str(params.lang_dir): + # Read tokens and create a simple word table mapping + tokens_file = params.lang_dir / "tokens.txt" + if tokens_file.exists(): + word_table = {} + with open(tokens_file, 'r') as f: + for line in f: + if line.strip(): + parts = line.strip().split() + if len(parts) >= 2: + token, idx = parts[0], parts[1] + word_table[token] = int(idx) + else: + word_table = None + else: + # Phone mode: use lexicon word table + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + + + + # Use validation-specific decoding parameters + if params.validation_decoding_method == "greedy": + logging.info("Starting decode_dataset with GREEDY decoding...") + # Override beam parameters for greedy decoding + original_search_beam = params.search_beam + original_output_beam = params.output_beam + params.search_beam = 1.0 # Greedy = beam size 1 + params.output_beam = 1.0 + else: + logging.info(f"Starting decode_dataset with BEAM search (search_beam={params.validation_search_beam}, output_beam={params.validation_output_beam})...") + # Use validation-specific beam parameters + original_search_beam = params.search_beam + original_output_beam = params.output_beam + params.search_beam = params.validation_search_beam + params.output_beam = params.validation_output_beam + + try: + results_dict = decode_dataset( + dl=valid_dl, + params=params, + model=model, + rnn_lm_model=None, # For CTC validation, we don't use RNN LM + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=word_table, + sos_id=sos_id, + eos_id=eos_id, + ) + + except Exception as e: + logging.error(f"decode_dataset failed: {e}") + logging.error("Skipping WER computation for this validation") + # Restore original beam parameters + params.search_beam = original_search_beam + params.output_beam = original_output_beam + + logging.info(f"Validation loss: {loss_value:.4f}") + return tot_loss, None + + # Restore original beam parameters + params.search_beam = original_search_beam + params.output_beam = 
original_output_beam + + logging.info("Starting save_results...") + + wer_results = save_results(params=params, test_set_name=f"epoch_{epoch}_validation", results_dict=results_dict) + + # Log WER results + if wer_results: + for method, wer_value in wer_results.items(): + logging.info(f"Dataset-level WER ({method}): {wer_value:.2f}% (total errors/total words)") + # Log each WER method to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar(f"validation/wer_{method}", wer_value, params.batch_idx_train) + else: + logging.info("Validation WER: N/A") + + # Log some example predictions vs ground truth for inspection + log_prediction_examples(results_dict, max_examples=3) + + # Log examples to TensorBoard if available + if rank == 0 and tb_writer is not None: + log_validation_examples_to_tensorboard(results_dict, tb_writer, params.batch_idx_train, max_examples=5) + + # Calculate overall WER statistics if we have results + overall_wer = None + if wer_results: + # Find the main WER method (usually the first one or the one with 'wer' in the name) + main_wer_key = None + for key in wer_results.keys(): + if 'wer' in key.lower() or 'word_error_rate' in key.lower(): + main_wer_key = key + break + + if main_wer_key is None and wer_results: + # If no specific WER key found, use the first one + main_wer_key = list(wer_results.keys())[0] + + if main_wer_key: + overall_wer = wer_results[main_wer_key] + logging.info(f"Main dataset-level WER ({main_wer_key}): {overall_wer:.2f}% (total errors/total words)") + # Log the main/total WER to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar("validation/total_wer", overall_wer, params.batch_idx_train) + tb_writer.add_scalar("validation/wer_dataset_level", overall_wer, params.batch_idx_train) + + # Final logging of validation results + logging.info(f"Validation loss: {loss_value:.4f}") + if overall_wer is not None: + logging.info(f"Total validation WER: {overall_wer:.2f}% (dataset-level)") + # Log the final total WER to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar("validation/loss", loss_value, params.batch_idx_train) + tb_writer.add_scalar("validation/total_wer", overall_wer, params.batch_idx_train) + else: + logging.info("Validation WER: N/A") + + return tot_loss, overall_wer def train_one_epoch( @@ -498,6 +795,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. 
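Note on the validation WER numbers introduced by this patch: write_error_stats (called via save_results) reports an alignment-based WER, while the per-example figures printed by decode_dataset and log_prediction_examples use a quicker difflib approximation that can overcount substitutions. A minimal standalone sketch of that approximation follows; the helper name approx_wer and the example strings are illustrative only, not part of the patch.

import difflib

def approx_wer(ref_words, hyp_words):
    # Count the words that SequenceMatcher could not align on either side.
    # For the found alignment this equals deletions + insertions + 2 * substitutions,
    # so it is an upper bound on the Levenshtein-based WER numerator.
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)
    matched = sum(block.size for block in matcher.get_matching_blocks())
    errors = len(ref_words) + len(hyp_words) - 2 * matched
    return 100.0 * errors / len(ref_words) if ref_words else 0.0

# One substitution ("cat" -> "hat") in a four-word reference:
# approx_wer gives 50.0 here, whereas an edit-distance WER would give 25.0.
print(approx_wer("the cat sat down".split(), "the hat sat down".split()))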
@@ -563,21 +861,72 @@ def train_one_epoch( ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( + if batch_idx > 0 and batch_idx % params.valid_interval == 0 and params.enable_validation: + logging.info(f"Computing validation loss (rank {rank})") + + + # Use quick validation for frequent checks, full validation less frequently + quick_val = (params.batch_idx_train % (params.valid_interval * 5) != 0) + valid_info, validation_wer = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, + epoch=params.cur_epoch, + quick_validation=quick_val, + rank=rank, + tb_writer=tb_writer, ) + + + # Log validation results with WER if available + if validation_wer is not None: + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}, WER: {validation_wer:.2f}%") + else: + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + + # Save checkpoint after validation (only rank 0) + if rank == 0: + logging.info(f"Saving checkpoint after validation at batch {batch_idx}") + try: + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + suffix=f"val-{batch_idx}", + wer_value=validation_wer, + step=batch_idx, + ) + logging.info(f"Checkpoint saved successfully for batch {batch_idx}") + except Exception as e: + logging.error(f"Failed to save checkpoint: {e}") + # Continue training even if checkpoint saving fails model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + + if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + + # Write WER to TensorBoard if the validation results file exists (save_results writes it under exp_dir/results) + wer_summary_file = params.exp_dir / "results" / f"wer-summary-epoch_{params.cur_epoch}_validation.txt" + if wer_summary_file.exists(): + try: + with open(wer_summary_file, 'r') as f: + lines = f.readlines() + for line in lines[1:]: # Skip header line + if line.strip(): + parts = line.strip().split('\t') + if len(parts) >= 2: + method_name = parts[0] + wer_value = float(parts[1]) + tb_writer.add_scalar(f"train/valid_WER_{method_name}", wer_value, params.batch_idx_train) + except Exception as e: + logging.warning(f"Could not log WER to TensorBoard: {e}") + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value @@ -607,6 +956,7 @@ def run(rank, world_size, args): setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") + logging.info(f"Warmup steps: {params.warm_step}") logging.info(params) if args.tensorboard and rank == 0: @@ -614,10 +964,6 @@ def run(rank, world_size, args): else: tb_writer = None - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) @@ -629,6 +975,11 @@ def run(rank, world_size, args): sos_token="<sos/eos>", eos_token="<sos/eos>", ) + # Read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + num_classes = len(f.readlines()) + max_token_id = num_classes - 1 elif "lang_phone" in str(params.lang_dir): assert params.att_rate == 0, ( "Attention decoder training does not support phone lang dirs " "when --att-rate is non-zero. " "Set 
--num-decoder-layers=0 for pure CTC training when using " "a phone-based lang dir." ) + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank graph_compiler = CtcTrainingGraphCompiler( lexicon, device=device, @@ -671,7 +1025,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Noam( model.parameters(), @@ -706,17 +1060,22 @@ def run(rank, world_size, args): train_dl = librispeech.train_dataloaders(train_cuts) + # Use only dev_clean for faster validation (dev_other can be added later) valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() + # valid_cuts += librispeech.dev_other_cuts() # Comment out for faster validation valid_dl = librispeech.valid_dataloaders(valid_cuts) + + logging.info(f"Validation set size: {len(valid_cuts)} utterances") - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) + if params.sanity_check: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + else: pass for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) @@ -741,6 +1100,7 @@ def run(rank, world_size, args): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) save_checkpoint( @@ -796,13 +1156,252 @@ def scan_pessimistic_batches_for_oom( raise +def log_prediction_examples(results_dict, max_examples=5, force_log=False): + """ + Log a few examples of ground truth vs predicted text for validation inspection. + Only logs to terminal every 50 validation samples to reduce clutter. 
+ + Args: + results_dict: Dictionary containing decoding results + max_examples: Maximum number of examples to log + force_log: Force logging regardless of sample counter + """ + global _VALIDATION_SAMPLE_COUNTER + + if not results_dict: + return + + # Get the first method's results (usually there's only one method in validation) + first_method = list(results_dict.keys())[0] + results = results_dict[first_method] + + if not results: + return + + # Update the validation sample counter + _VALIDATION_SAMPLE_COUNTER += len(results) + + # Only log to terminal every 50 samples (or when forced) + should_log_to_terminal = force_log or (_VALIDATION_SAMPLE_COUNTER % 50 == 0) or (_VALIDATION_SAMPLE_COUNTER <= 50) + + if not should_log_to_terminal: + # Still compute and log basic statistics, just not the detailed examples + total_sample_wer = 0 + valid_samples = 0 + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if len(ref_word_list) > 0: + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = len(ref_word_list) + len(hyp_word_list) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_word_list)) * 100 + total_sample_wer += utt_wer + valid_samples += 1 + + # Log summary info only + if valid_samples > 0: + avg_example_wer = total_sample_wer / valid_samples + logging.info(f"Validation batch processed: {valid_samples} samples " + f"(total samples processed: {_VALIDATION_SAMPLE_COUNTER}, detailed examples every 50 samples)") + return + + # Full detailed logging when we hit the 50-sample threshold + logging.info(f"Detailed validation examples (sample #{_VALIDATION_SAMPLE_COUNTER - len(results) + 1}-{_VALIDATION_SAMPLE_COUNTER}):") + + # Select diverse examples: some short, some long, some with errors, some perfect + selected_examples = [] + + # Try to get diverse examples by length and error type + perfect_matches = [] + error_cases = [] + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + if ref_text.split() == hyp_text.split(): + perfect_matches.append(result) + else: + error_cases.append(result) + + # Mix perfect matches and error cases + selected_examples = error_cases[:max_examples-1] + perfect_matches[:1] + if len(selected_examples) < max_examples: + selected_examples.extend(results[:max_examples - len(selected_examples)]) + + selected_examples = selected_examples[:max_examples] + + logging.info("=" * 80) + logging.info(f"VALIDATION EXAMPLES (showing {len(selected_examples)} samples):") + logging.info("=" * 80) + + total_sample_wer = 0 + valid_samples = 0 + + for i, result in enumerate(selected_examples): + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + + # Convert word lists to strings + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + logging.info(f"Example {i+1} (ID: {cut_id}):") + logging.info(f" REF: 
{ref_text}") + logging.info(f" HYP: {hyp_text}") + + # Simple word error analysis + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if ref_word_list == hyp_word_list: + logging.info(f" --> ✅ PERFECT MATCH ({len(ref_word_list)} words, WER: 0.0%)") + total_sample_wer += 0.0 + valid_samples += 1 + else: + # Basic error analysis + ref_len = len(ref_word_list) + hyp_len = len(hyp_word_list) + + # Calculate simple WER for this utterance + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = ref_len + hyp_len - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / ref_len * 100) if ref_len > 0 else 0 + total_sample_wer += utt_wer + valid_samples += 1 + + # Find common words for basic analysis + ref_set = set(ref_word_list) + hyp_set = set(hyp_word_list) + missing_words = ref_set - hyp_set + extra_words = hyp_set - ref_set + + error_info = f"WER: {utt_wer:.1f}%, REF: {ref_len} words, HYP: {hyp_len} words" + if missing_words and len(missing_words) <= 3: + error_info += f", Missing: {list(missing_words)}" + elif missing_words: + error_info += f", Missing: {len(missing_words)} words" + + if extra_words and len(extra_words) <= 3: + error_info += f", Extra: {list(extra_words)}" + elif extra_words: + error_info += f", Extra: {len(extra_words)} words" + + logging.info(f" --> ❌ ERRORS ({error_info})") + logging.info("") + + # Log average WER for the examples + if valid_samples > 0: + avg_example_wer = total_sample_wer / valid_samples + logging.info(f"Average WER for these {valid_samples} examples: {avg_example_wer:.2f}%") + + logging.info("=" * 80) + + +def log_validation_examples_to_tensorboard(results_dict, tb_writer, step, max_examples=5): + """ + Log validation examples to TensorBoard as text. 
+ + Args: + results_dict: Dictionary containing decoding results + tb_writer: TensorBoard writer + step: Current training step + max_examples: Maximum number of examples to log + """ + if not results_dict or tb_writer is None: + return + + # Get the first method's results + first_method = list(results_dict.keys())[0] + results = results_dict[first_method] + + if not results: + return + + # Select diverse examples + selected_examples = [] + perfect_matches = [] + error_cases = [] + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + if ref_text.split() == hyp_text.split(): + perfect_matches.append(result) + else: + error_cases.append(result) + + # Mix error cases and perfect matches + selected_examples = error_cases[:max_examples-1] + perfect_matches[:1] + if len(selected_examples) < max_examples: + selected_examples.extend(results[:max_examples - len(selected_examples)]) + + selected_examples = selected_examples[:max_examples] + + # Create text to log to TensorBoard + tb_text = "## Validation Examples\n\n" + + total_wer = 0 + valid_count = 0 + + for i, result in enumerate(selected_examples): + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + tb_text += f"**Example {i+1} (ID: {cut_id})**\n\n" + tb_text += f"- **REF:** {ref_text}\n" + tb_text += f"- **HYP:** {hyp_text}\n" + + # Calculate simple WER for this utterance + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if ref_word_list == hyp_word_list: + tb_text += f"- **Result:** ✅ PERFECT MATCH ({len(ref_word_list)} words, WER: 0.0%)\n\n" + total_wer += 0.0 + valid_count += 1 + else: + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = len(ref_word_list) + len(hyp_word_list) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_word_list) * 100) if len(ref_word_list) > 0 else 0 + tb_text += f"- **Result:** ❌ WER: {utt_wer:.1f}% (REF: {len(ref_word_list)} words, HYP: {len(hyp_word_list)} words)\n\n" + total_wer += utt_wer + valid_count += 1 + + # Add summary statistics + if valid_count > 0: + avg_wer = total_wer / valid_count + tb_text += f"**Summary:** Average WER for {valid_count} examples: {avg_wer:.2f}%\n\n" + + # Log to TensorBoard + tb_writer.add_text("Validation/Examples", tb_text, step) + + def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) - + args.bpe_dir = Path(args.bpe_dir) world_size = args.world_size assert world_size >= 1 if world_size > 1: @@ -811,8 +1410,6 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py 
b/egs/librispeech/ASR/conformer_ctc2/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc2/subsampling.py b/egs/librispeech/ASR/conformer_ctc2/subsampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc3/model.py b/egs/librispeech/ASR/conformer_ctc3/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_ctc_sd/.vscode/settings.json b/egs/librispeech/ASR/conformer_ctc_sd/.vscode/settings.json new file mode 100644 index 000000000..dd4530b3c --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "workbench.colorCustomizations": { + "activityBar.activeBackground": "#92a87d", + "activityBar.background": "#92a87d", + "activityBar.foreground": "#15202b", + "activityBar.inactiveForeground": "#15202b99", + "activityBarBadge.background": "#596f86", + "activityBarBadge.foreground": "#e7e7e7", + "commandCenter.border": "#15202b99", + "sash.hoverBorder": "#92a87d", + "statusBar.background": "#799161", + "statusBar.foreground": "#15202b", + "statusBarItem.hoverBackground": "#5f724d", + "statusBarItem.remoteBackground": "#799161", + "statusBarItem.remoteForeground": "#15202b", + "titleBar.activeBackground": "#799161", + "titleBar.activeForeground": "#15202b", + "titleBar.inactiveBackground": "#79916199", + "titleBar.inactiveForeground": "#15202b99" + }, + "peacock.remoteColor": "#799161" +} \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc_sd/README.md b/egs/librispeech/ASR/conformer_ctc_sd/README.md new file mode 100755 index 000000000..1bccccc73 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/README.md @@ -0,0 +1,75 @@ +## Introduction + +Please visit + +for how to run this recipe. + +## How to compute framewise alignment information + +### Step 1: Train a model + +Please use `conformer_ctc/train.py` to train a model. +See +for how to do it. + +### Step 2: Compute framewise alignment + +Run + +``` +# Choose a checkpoint and determine the number of checkpoints to average +epoch=30 +avg=15 +./conformer_ctc/ali.py \ + --epoch $epoch \ + --avg $avg \ + --max-duration 500 \ + --bucketing-sampler 0 \ + --full-libri 1 \ + --exp-dir conformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --ali-dir data/ali_500 +``` +and you will get four files inside the folder `data/ali_500`: + +``` +$ ls -lh data/ali_500 +total 546M +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt +-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt +-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt +``` + +**Note**: It can take more than 3 hours to compute the alignment +for the training dataset, which contains 960 * 3 = 2880 hours of data. + +**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those +in `conformer_ctc/train.py`. + +**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`. +Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. + +### Step 3: Check your extracted alignments + +There is a file `test_ali.py` in `icefall/test` that can be used to test your +alignments. It uses pre-computed alignments to modify a randomly generated +`nnet_output` and it checks that we can decode the correct transcripts +from the resulting `nnet_output`. 
+ +You should get something like the following if you run that script: + +``` +$ ./test/test_ali.py +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +``` + +### Step 4: Use your alignments in training + +Please refer to `conformer_mmi/train.py` for usage. Some useful +functions are: + +- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py` +- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors +- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances. diff --git a/egs/aidatatang_200zh/ASR/local/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from egs/aidatatang_200zh/ASR/local/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/__init__.py diff --git a/egs/mgb2/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc_sd/ali.py similarity index 98% rename from egs/mgb2/ASR/conformer_ctc/ali.py rename to egs/librispeech/ASR/conformer_ctc_sd/ali.py index aea962dcd..42e14abac 100755 --- a/egs/mgb2/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/ali.py @@ -96,14 +96,14 @@ def get_parser(): - labels_xxx.h5 - aux_labels_xxx.h5 - - cuts_xxx.json.gz + - librispeech_cuts_xxx.jsonl.gz where xxx is the value of `--dataset`. For instance, if `--dataset` is `train-clean-100`, it will contain 3 files: - `labels_train-clean-100.h5` - `aux_labels_train-clean-100.h5` - - `cuts_train-clean-100.json.gz` + - `librispeech_cuts_train-clean-100.jsonl.gz` Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise alignment. The difference is that labels_xxx.h5 contains repeats. 
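The manifest written by `ali.py` now uses the lhotse JSONL format (`librispeech_cuts_xxx.jsonl.gz` instead of `cuts_xxx.json.gz`). A small sketch, assuming lhotse is installed and the alignment step has already produced the file, of reading it back; the path is illustrative and follows the `--ali-dir` value used in the README above:

```python
from lhotse import CutSet

cuts = CutSet.from_file("data/ali_500/librispeech_cuts_train-clean-100.jsonl.gz")
print(f"{len(cuts)} cuts in the aligned manifest")
for cut in cuts:
    print(cut.id, cut.duration)
    break
```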
@@ -285,7 +285,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/aishell/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc_sd/asr_datamodule.py similarity index 100% rename from egs/aishell/ASR/conformer_ctc/asr_datamodule.py rename to egs/librispeech/ASR/conformer_ctc_sd/asr_datamodule.py diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/conformer_ctc_sd/conformer.py old mode 100644 new mode 100755 similarity index 72% rename from egs/aishell/ASR/transducer_stateless/conformer.py rename to egs/librispeech/ASR/conformer_ctc_sd/conformer.py index 78424aea2..8c1529500 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/conformer.py @@ -15,60 +15,67 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import copy import math import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) from torch import Tensor, nn -from transformer import Transformer from icefall.utils import make_pad_mask -class Conformer(Transformer): +class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features - output_dim (int): Number of output dimension subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension + d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. """ def __init__( self, num_features: int, - output_dim: int, subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, num_encoder_layers: int = 12, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, + middle_output_layer: int = None, # 0-based layer index ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - output_dim=output_dim, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - ) + super(Conformer, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). 
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -77,21 +84,25 @@ class Conformer(Transformer): nhead, dim_feedforward, dropout, + layer_dropout, cnn_module_kernel, - normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity + + output_layers = [] + if middle_output_layer is not None: + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers + output_layers.append(middle_output_layer) + + # The last layer is always needed. + output_layers.append(num_encoder_layers - 1) + + self.encoder = ConformerEncoder( + encoder_layer, num_encoder_layers, output_layers=output_layers + ) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Args: x: @@ -99,30 +110,35 @@ class Conformer(Transformer): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. + - embeddings: its shape is (batch_size, output_seq_len, d_model) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. """ x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 + assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + layer_results = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) - if self.normalize_before: - x = self.after_norm(x) - - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return logits, lengths + return layer_results, lengths class ConformerEncoderLayer(nn.Module): @@ -136,7 +152,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. 
Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -151,47 +166,51 @@ class ConformerEncoderLayer(nn.Module): nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, - normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.d_model = d_model + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_final = BasicNorm(d_model) - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) self.dropout = nn.Dropout(dropout) - self.normalize_before = normalize_before - def forward( self, src: Tensor, pos_emb: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, ) -> Tensor: """ Pass the input through the encoder layer. @@ -201,6 +220,8 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. Shape: src: (S, N, E). @@ -209,19 +230,24 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. 
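The comment above describes the stochastic layer-skipping used during warm-up. A compact sketch of that blending rule, pulled out for illustration (the function name is hypothetical; the constants mirror the layer code):

```python
import torch

def blend_with_bypass(layer_out, layer_in, warmup, layer_dropout=0.075, training=True):
    # warmup_scale ramps from 0.1 toward 1.0 as warmup goes from 0 to >= 0.9
    warmup_scale = min(0.1 + warmup, 1.0)
    if training:
        # with probability `layer_dropout` the layer is mostly bypassed (alpha = 0.1)
        alpha = warmup_scale if torch.rand(()).item() <= (1.0 - layer_dropout) else 0.1
    else:
        alpha = 1.0  # always use the full layer at inference time
    # alpha = 1.0 keeps the layer output unchanged; smaller alpha mixes the input back in
    return layer_out if alpha == 1.0 else alpha * layer_out + (1 - alpha) * layer_in

out = blend_with_bypass(torch.randn(10, 2, 256), torch.randn(10, 2, 256), warmup=0.3)
```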
+ if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) - if not self.normalize_before: - src = self.norm_ff_macaron(src) + src = src + self.dropout(self.feed_forward_macaron(src)) # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) src_att = self.self_attn( src, src, @@ -230,41 +256,30 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) + src = src + self.dropout(src_att) # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout( + src = src + self.dropout( self.conv_module(src, src_key_padding_mask=src_key_padding_mask) ) - if not self.normalize_before: - src = self.norm_conv(src) # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) + src = src + self.dropout(self.feed_forward(src)) - if self.normalize_before: - src = self.norm_final(src) + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig return src -class ConformerEncoder(nn.TransformerEncoder): +class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers Args: encoder_layer: an instance of the ConformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -275,11 +290,17 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, + encoder_layer: nn.Module, + num_layers: int, + output_layers: List[int], ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) + self.num_layers = num_layers + self.output_layers = output_layers def forward( self, @@ -287,7 +308,8 @@ class ConformerEncoder(nn.TransformerEncoder): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + warmup: float = 1.0, + ) -> List[Tensor]: r"""Pass the input through the encoder layers in turn. 
Args: @@ -306,18 +328,20 @@ class ConformerEncoder(nn.TransformerEncoder): """ output = src - for mod in self.layers: + layer_results = [] + for i, mod in enumerate(self.layers): output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + warmup=warmup, ) + if i in self.output_layers: + # (T, N, C) --> (N, T, C) + layer_results.append(output.permute(1, 0, 2)) - if self.norm is not None: - output = self.norm(output) - - return output + return layer_results class RelPositionalEncoding(torch.nn.Module): @@ -337,7 +361,6 @@ class RelPositionalEncoding(torch.nn.Module): """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model - self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -387,7 +410,6 @@ class RelPositionalEncoding(torch.nn.Module): """ self.extend_pe(x) - x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 @@ -429,25 +451,30 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) def forward( self, @@ -507,11 +534,11 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.weight, - self.in_proj.bias, + self.in_proj.get_weight(), + self.in_proj.get_bias(), self.dropout, - self.out_proj.weight, - self.out_proj.bias, + self.out_proj.get_weight(), + self.out_proj.get_bias(), training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -615,6 +642,7 @@ class RelPositionMultiheadAttention(nn.Module): assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 if torch.equal(query, key) and torch.equal(key, value): @@ -633,6 +661,7 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight 
and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -710,7 +739,7 @@ class RelPositionMultiheadAttention(nn.Module): ) key_padding_mask = key_padding_mask.to(torch.bool) - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) @@ -731,11 +760,11 @@ class RelPositionMultiheadAttention(nn.Module): p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose( + q_with_bias_u = (q + self._pos_bias_u()).transpose( 1, 2 ) # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose( + q_with_bias_v = (q + self._pos_bias_v()).transpose( 1, 2 ) # (batch, head, time1, d_k) @@ -751,9 +780,7 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) @@ -820,7 +847,7 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = nn.Conv1d( + self.pointwise_conv1 = ScaledConv1d( channels, 2 * channels, kernel_size=1, @@ -828,7 +855,25 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - self.depthwise_conv = nn.Conv1d( + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. 
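For reference, the gated linear unit mentioned in the comment splits the `2 * channels` output of `pointwise_conv1` into a linear half and a sigmoid gate. A two-line sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 256, 100)  # (batch, 2*channels, time) as produced by pointwise_conv1
y = F.glu(x, dim=1)               # first half * sigmoid(second half) -> (batch, channels, time)
assert y.shape == (4, 256, 100)
```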
+ self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + self.depthwise_conv = ScaledConv1d( channels, channels, kernel_size, @@ -837,16 +882,22 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, + initial_scale=0.25, ) - self.activation = Swish() def forward( self, @@ -868,17 +919,16 @@ class ConvolutionModule(nn.Module): # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv if src_key_padding_mask is not None: x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) - # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.norm(x) - x = x.permute(0, 2, 1) + x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) @@ -886,13 +936,108 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) -class Swish(torch.nn.Module): - """Construct an Swish object.""" +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+ + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x -def identity(x): - return x +if __name__ == "__main__": + feature_dim = 50 + c = Conformer(num_features=feature_dim, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) diff --git a/egs/librispeech/ASR/conformer_ctc_sd/conformer_ctc.py b/egs/librispeech/ASR/conformer_ctc_sd/conformer_ctc.py new file mode 100644 index 000000000..96ce78541 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/conformer_ctc.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Conformer CTC model with support for self-distillation on encoder outputs and attention maps. +""" + +import logging +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import Conformer +from conformer_with_attention import ConformerWithAttention + +# from icefall.utils import add_sos_eos, is_jit_tracing, is_jit_scripting + + +class ConformerCTC(nn.Module): + """Conformer CTC model with self-distillation support. + + This model extends the basic Conformer encoder to support extraction of + intermediate layer outputs and attention maps for self-distillation. 
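A hedged usage sketch of `ConformerCTC` with encoder-output distillation; all values are illustrative, and it assumes the recipe directory is on the import path so `conformer.py`, `scaling.py`, etc. can be found:

```python
import torch
from conformer_ctc import ConformerCTC  # this file

model = ConformerCTC(
    num_features=80,          # Fbank dimension
    num_classes=500,          # illustrative BPE vocabulary size
    d_model=256,
    nhead=4,
    num_encoder_layers=12,
    distill_layers=[5, 8],    # 0-based intermediate layers used for distillation
    knowledge_type="encoder-output",
)
model.eval()

x = torch.randn(2, 300, 80)   # (N, T, num_features)
with torch.no_grad():
    out = model(x)

print(out["ctc_output"].shape)         # (N, T', num_classes)
print(sorted(out["distill_outputs"]))  # [5, 8]
```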
+ + Args: + num_features: Number of input features (e.g., 80 for Fbank) + num_classes: Number of output classes (vocabulary size) + subsampling_factor: Subsampling factor of encoder + d_model: Model dimension + nhead: Number of attention heads + dim_feedforward: Feedforward dimension + num_encoder_layers: Number of encoder layers + dropout: Dropout rate + cnn_module_kernel: Convolution module kernel size + distill_layers: List of layer indices for distillation (0-based) + knowledge_type: Type of knowledge to extract ('encoder-output' or 'attention-map') + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + distill_layers: Optional[List[int]] = None, + knowledge_type: str = "encoder-output", + ) -> None: + super().__init__() + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + self.d_model = d_model + self.distill_layers = distill_layers or [] + self.knowledge_type = knowledge_type + + # Determine which layers need to output intermediate results + output_layers = [] + if distill_layers: + output_layers.extend(distill_layers) + + # Create conformer encoder with support for intermediate outputs and attention maps + if knowledge_type == "attention-map": + self.encoder = ConformerWithAttention( + num_features=num_features, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + cnn_module_kernel=cnn_module_kernel, + attention_layers=distill_layers, + ) + else: + self.encoder = Conformer( + num_features=num_features, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + cnn_module_kernel=cnn_module_kernel, + ) + + # Modify the encoder to output intermediate layers + if distill_layers and knowledge_type != "attention-map": + # Add the last layer to always be included + output_layers = list(set(distill_layers + [num_encoder_layers - 1])) + output_layers.sort() + self.encoder.encoder.output_layers = output_layers + + # CTC output projection + self.ctc_output = nn.Linear(d_model, num_classes) + + def _modify_encoder_for_attention_maps(self): + """Modify the encoder layers to extract attention maps using forward hooks.""" + # Store attention weights during forward pass + self._attention_storage = {} + self._hooks = [] + + # Register forward hooks instead of modifying methods directly + for layer_idx in self.distill_layers: + if layer_idx < len(self.encoder.encoder.layers): + layer = self.encoder.encoder.layers[layer_idx] + + # Create hook function that captures attention weights + def create_attention_hook(idx): + def attention_hook(module, input, output): + # This hook will be called after the layer's forward pass + # We need to modify the layer to store attention weights + pass + return attention_hook + + # Register the hook + hook = layer.register_forward_hook(create_attention_hook(layer_idx)) + self._hooks.append(hook) + + # Alternative: Monkey patch the self_attn module specifically + for layer_idx in self.distill_layers: + if layer_idx < len(self.encoder.encoder.layers): + layer = self.encoder.encoder.layers[layer_idx] + original_self_attn = layer.self_attn + + def 
create_patched_attention(orig_attn, idx): + class PatchedAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.original_attn = orig_attn + self.layer_idx = idx + + def forward(self, query, key, value, pos_emb=None, attn_mask=None, + key_padding_mask=None, need_weights=False): + # Always request attention weights for distillation layers + output, attn_weights = self.original_attn( + query, key, value, pos_emb=pos_emb, attn_mask=attn_mask, + key_padding_mask=key_padding_mask, need_weights=True + ) + + # Store attention weights in parent module + if attn_weights is not None and hasattr(self, '_parent_storage'): + self._parent_storage[self.layer_idx] = attn_weights + + # Return in expected format + if need_weights: + return output, attn_weights + else: + return output + + patched = PatchedAttention() + patched._parent_storage = self._attention_storage + return patched + + # Replace the self_attn module + layer.self_attn = create_patched_attention(original_self_attn, layer_idx) + + def cleanup_hooks(self): + """Clean up forward hooks to prevent memory leaks.""" + if hasattr(self, '_hooks'): + for hook in self._hooks: + hook.remove() + self._hooks.clear() + + def forward( + self, + x: torch.Tensor, + supervisions: Optional[Dict] = None + ) -> Dict[str, torch.Tensor]: + """Forward pass with support for distillation knowledge extraction. + + Args: + x: Input tensor of shape (N, T, num_features) + supervisions: Supervision information (optional) + + Returns: + Dictionary containing: + - ctc_output: CTC output logits (N, T', num_classes) + - encoder_out: Final encoder output (N, T', d_model) + - encoder_out_lens: Sequence lengths after subsampling + - distill_outputs: Intermediate layer outputs for distillation + - attention_maps: Attention maps if knowledge_type is 'attention-map' + """ + # Get sequence lengths + x_lens = x.new_zeros(x.size(0)).long() + x.size(1) + + if self.knowledge_type == "attention-map": + # Use ConformerWithAttention + layer_outputs, output_lens, attention_maps = self.encoder(x, x_lens) + else: + # Use regular Conformer + layer_outputs, output_lens = self.encoder(x, x_lens) + attention_maps = {} + + # The last output is the final encoder output + encoder_out = layer_outputs[-1] # (N, T', d_model) + + # CTC output projection + ctc_output = self.ctc_output(encoder_out) # (N, T', num_classes) + + # Prepare return dictionary + result = { + 'ctc_output': ctc_output, + 'encoder_out': encoder_out, + 'encoder_out_lens': output_lens, + } + + # Extract distillation knowledge based on type + if self.distill_layers: + if self.knowledge_type == "encoder-output": + # Extract encoder outputs from specified layers + distill_outputs = {} + for i, layer_idx in enumerate(self.distill_layers): + if i < len(layer_outputs): + distill_outputs[layer_idx] = layer_outputs[i] + result['distill_outputs'] = distill_outputs + + # For backward compatibility, also provide single distill_hidden + if len(self.distill_layers) == 1: + result['distill_hidden'] = layer_outputs[0] + + elif self.knowledge_type == "attention-map": + # Return attention maps from ConformerWithAttention + result['attention_maps'] = attention_maps + + return result + + +def compute_distillation_loss( + teacher_knowledge: torch.Tensor, + student_knowledge: torch.Tensor, + knowledge_lens: torch.Tensor, + loss_type: str = "mse", + knowledge_type: str = "encoder-output", + temperature: float = 1.0, +) -> torch.Tensor: + """Compute distillation loss between teacher and student knowledge. 
+ + Args: + teacher_knowledge: Teacher knowledge tensor + student_knowledge: Student knowledge tensor + knowledge_lens: Sequence lengths for masking + loss_type: Type of loss ('mse', 'cosine' for encoder outputs; 'kl' for attention maps) + knowledge_type: Type of knowledge ('encoder-output' or 'attention-map') + temperature: Temperature for softmax (used with KL divergence) + + Returns: + Computed distillation loss + """ + if knowledge_type == "encoder-output": + # Handle encoder output distillation + if loss_type == "mse": + return _compute_mse_loss(teacher_knowledge, student_knowledge, knowledge_lens) + elif loss_type == "cosine": + return _compute_cosine_loss(teacher_knowledge, student_knowledge, knowledge_lens) + else: + raise ValueError(f"Unsupported loss type for encoder outputs: {loss_type}") + + elif knowledge_type == "attention-map": + # Handle attention map distillation + if loss_type == "kl": + return _compute_kl_divergence_loss(teacher_knowledge, student_knowledge, knowledge_lens, temperature) + else: + raise ValueError(f"Unsupported loss type for attention maps: {loss_type}") + else: + raise ValueError(f"Unsupported knowledge type: {knowledge_type}") + + +def _compute_mse_loss( + teacher_hidden: torch.Tensor, + student_hidden: torch.Tensor, + hidden_lens: torch.Tensor +) -> torch.Tensor: + """Compute MSE loss between teacher and student hidden states.""" + # teacher_hidden, student_hidden: (N, T, d_model) + # hidden_lens: (N,) + + batch_size, max_len, _ = teacher_hidden.shape + + # Create mask for valid positions + mask = torch.arange(max_len, device=hidden_lens.device)[None, :] < hidden_lens[:, None] + mask = mask.float() # (N, T) + + # Compute MSE loss element-wise + mse_loss = F.mse_loss(student_hidden, teacher_hidden, reduction='none') # (N, T, d_model) + mse_loss = mse_loss.mean(dim=-1) # (N, T) + + # Apply mask and compute mean + masked_loss = mse_loss * mask + total_loss = masked_loss.sum() + total_tokens = mask.sum() + + return total_loss / (total_tokens + 1e-8) + + +def _compute_cosine_loss( + teacher_hidden: torch.Tensor, + student_hidden: torch.Tensor, + hidden_lens: torch.Tensor +) -> torch.Tensor: + """Compute cosine similarity loss between teacher and student hidden states.""" + # teacher_hidden, student_hidden: (N, T, d_model) + # hidden_lens: (N,) + + batch_size, max_len, _ = teacher_hidden.shape + + # Create mask for valid positions + mask = torch.arange(max_len, device=hidden_lens.device)[None, :] < hidden_lens[:, None] + mask = mask.float() # (N, T) + + # Compute cosine similarity + cosine_sim = F.cosine_similarity(teacher_hidden, student_hidden, dim=-1) # (N, T) + + # Convert to loss (1 - cosine_similarity) + cosine_loss = 1.0 - cosine_sim # (N, T) + + # Apply mask and compute mean + masked_loss = cosine_loss * mask + total_loss = masked_loss.sum() + total_tokens = mask.sum() + + return total_loss / (total_tokens + 1e-8) + + +def _compute_kl_divergence_loss( + teacher_attention: torch.Tensor, + student_attention: torch.Tensor, + attention_lens: torch.Tensor, + temperature: float = 1.0, +) -> torch.Tensor: + """Compute KL divergence loss between teacher and student attention maps. 
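A short usage sketch of the encoder-output branch defined above; shapes and lengths are illustrative:

```python
import torch

teacher_hidden = torch.randn(2, 50, 256)   # (N, T, d_model) from the teacher pass
student_hidden = torch.randn(2, 50, 256)   # matching shape from the student pass
hidden_lens = torch.tensor([50, 37])       # valid frames per utterance

mse = compute_distillation_loss(
    teacher_hidden, student_hidden, hidden_lens,
    loss_type="mse", knowledge_type="encoder-output",
)
cos = compute_distillation_loss(
    teacher_hidden, student_hidden, hidden_lens,
    loss_type="cosine", knowledge_type="encoder-output",
)
print(mse.item(), cos.item())
```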
+ + Args: + teacher_attention: Teacher attention weights (N, H, T, T) or (N, T, T) + student_attention: Student attention weights (N, H, T, T) or (N, T, T) + attention_lens: Sequence lengths for masking (N,) + temperature: Temperature for softmax smoothing + + Returns: + KL divergence loss + """ + # Handle different attention map shapes + if teacher_attention.dim() == 4: + # (N, H, T, T) -> average over heads to get (N, T, T) + teacher_attention = teacher_attention.mean(dim=1) + student_attention = student_attention.mean(dim=1) + + batch_size, seq_len, _ = teacher_attention.shape + + # Create attention mask + mask = torch.arange(seq_len, device=attention_lens.device)[None, :] < attention_lens[:, None] + mask = mask.float() # (N, T) + + # Create 2D mask for attention matrix + mask_2d = mask.unsqueeze(-1) * mask.unsqueeze(-2) # (N, T, T) + + # Apply temperature and compute log probabilities + teacher_log_probs = F.log_softmax(teacher_attention / temperature, dim=-1) + student_log_probs = F.log_softmax(student_attention / temperature, dim=-1) + + # Convert student to probabilities for KL divergence + student_probs = F.softmax(student_attention / temperature, dim=-1) + + # Compute KL divergence: KL(student || teacher) = sum(student * (log(student) - log(teacher))) + kl_div = student_probs * (student_log_probs - teacher_log_probs) # (N, T, T) + kl_div = kl_div.sum(dim=-1) # (N, T) + + # Apply mask and compute mean + masked_kl = kl_div * mask + total_loss = masked_kl.sum() + total_tokens = mask.sum() + + return total_loss / (total_tokens + 1e-8) + + +def compute_multi_layer_distillation_loss( + teacher_knowledge: Dict[int, torch.Tensor], + student_knowledge: Dict[int, torch.Tensor], + knowledge_lens: torch.Tensor, + layer_indices: List[int], + loss_type: str = "mse", + knowledge_type: str = "encoder-output", + aggregation: str = "layer_avg", + temperature: float = 1.0, +) -> torch.Tensor: + """Compute multi-layer distillation loss with specified aggregation strategy. 
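And a matching sketch for the single-layer attention-map branch just above; head count, lengths, and temperature are illustrative:

```python
import torch

teacher_attn = torch.softmax(torch.randn(2, 4, 50, 50), dim=-1)  # (N, H, T, T)
student_attn = torch.softmax(torch.randn(2, 4, 50, 50), dim=-1)
attn_lens = torch.tensor([50, 37])

kl = compute_distillation_loss(
    teacher_attn, student_attn, attn_lens,
    loss_type="kl", knowledge_type="attention-map", temperature=2.0,
)
print(kl.item())
```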
+ + Args: + teacher_knowledge: Dictionary mapping layer indices to teacher knowledge tensors + student_knowledge: Dictionary mapping layer indices to student knowledge tensors + knowledge_lens: Sequence lengths for masking + layer_indices: List of layer indices to compute loss for + loss_type: Type of loss computation + knowledge_type: Type of knowledge being distilled + aggregation: Aggregation strategy ('layer_avg' or 'output_avg') + temperature: Temperature for softmax (attention maps) + + Returns: + Aggregated distillation loss + """ + if aggregation == "layer_avg": + # Compute loss for each layer and average them + total_loss = torch.tensor(0.0, device=knowledge_lens.device) + valid_layers = 0 + + for layer_idx in layer_indices: + if layer_idx in teacher_knowledge and layer_idx in student_knowledge: + layer_loss = compute_distillation_loss( + teacher_knowledge[layer_idx], + student_knowledge[layer_idx], + knowledge_lens, + loss_type, + knowledge_type, + temperature, + ) + total_loss += layer_loss + valid_layers += 1 + + if valid_layers > 0: + return total_loss / valid_layers + else: + return torch.tensor(0.0, device=knowledge_lens.device) + + elif aggregation == "output_avg": + # Average the layer outputs first, then compute a single loss + if knowledge_type == "encoder-output": + # Stack and average encoder outputs + teacher_outputs = [] + student_outputs = [] + + for layer_idx in layer_indices: + if layer_idx in teacher_knowledge and layer_idx in student_knowledge: + teacher_outputs.append(teacher_knowledge[layer_idx]) + student_outputs.append(student_knowledge[layer_idx]) + + if not teacher_outputs: + return torch.tensor(0.0, device=knowledge_lens.device) + + # Average the outputs + avg_teacher = torch.stack(teacher_outputs).mean(dim=0) + avg_student = torch.stack(student_outputs).mean(dim=0) + + return compute_distillation_loss( + avg_teacher, avg_student, knowledge_lens, loss_type, knowledge_type, temperature + ) + + elif knowledge_type == "attention-map": + # Average attention maps and compute KL divergence + teacher_attentions = [] + student_attentions = [] + + for layer_idx in layer_indices: + if layer_idx in teacher_knowledge and layer_idx in student_knowledge: + teacher_attentions.append(teacher_knowledge[layer_idx]) + student_attentions.append(student_knowledge[layer_idx]) + + if not teacher_attentions: + return torch.tensor(0.0, device=knowledge_lens.device) + + # Average the attention maps + avg_teacher_attention = torch.stack(teacher_attentions).mean(dim=0) + avg_student_attention = torch.stack(student_attentions).mean(dim=0) + + return compute_distillation_loss( + avg_teacher_attention, avg_student_attention, knowledge_lens, + loss_type, knowledge_type, temperature + ) + else: + raise ValueError(f"Unsupported aggregation strategy: {aggregation}") diff --git a/egs/librispeech/ASR/conformer_ctc_sd/conformer_with_attention.py b/egs/librispeech/ASR/conformer_ctc_sd/conformer_with_attention.py new file mode 100644 index 000000000..b88aa392c --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/conformer_with_attention.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 + +""" +Conformer with attention map extraction support for multi-GPU training. 
+""" + +import torch +import torch.nn as nn +from typing import Dict, List, Optional, Tuple +from conformer import Conformer, ConformerEncoder, ConformerEncoderLayer + + +class ConformerEncoderLayerWithAttention(ConformerEncoderLayer): + """ConformerEncoderLayer that can optionally return attention weights.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.return_attention = False + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + return_attention: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass with optional attention weight return.""" + + # Store original parameters + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att, attn_weights = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + need_weights=return_attention or self.return_attention, + ) + + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, attn_weights if (return_attention or self.return_attention) else None + + +class ConformerEncoderWithAttention(ConformerEncoder): + """ConformerEncoder that can extract attention maps from specified layers.""" + + def __init__(self, *args, attention_layers: Optional[List[int]] = None, **kwargs): + super().__init__(*args, **kwargs) + self.attention_layers = attention_layers or [] + + # Replace layers with attention-capable versions + new_layers = [] + for i, layer in enumerate(self.layers): + # Create new layer with same parameters + new_layer = ConformerEncoderLayerWithAttention( + d_model=layer.d_model, + nhead=layer.self_attn.num_heads, + dim_feedforward=layer.feed_forward.w_1.in_features, + dropout=layer.dropout.p, + activation=layer.feed_forward.activation, + layer_dropout=layer.layer_dropout, + cnn_module_kernel=layer.conv_module.pointwise_conv2.out_channels // layer.conv_module.pointwise_conv1.out_channels, + ) + + # Copy weights + new_layer.load_state_dict(layer.state_dict()) + + # Enable attention return for specified layers + if i in self.attention_layers: + new_layer.return_attention = True + + new_layers.append(new_layer) + + self.layers = nn.ModuleList(new_layers) + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[List[torch.Tensor], torch.Tensor, Dict[int, torch.Tensor]]: + """Forward pass returning layer outputs and attention maps.""" + + output = src + layer_outputs = [] + attention_maps = {} + + for i, layer in enumerate(self.layers): + output, attn_weights = layer( + output, + pos_emb, + src_mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + return_attention=i in self.attention_layers, + ) + + # Store layer output if 
needed + if i in self.output_layers: + layer_outputs.append(output) + + # Store attention weights if available + if attn_weights is not None: + attention_maps[i] = attn_weights + + # Ensure we always have the final layer output + if len(self.output_layers) == 0 or (len(self.layers) - 1) not in self.output_layers: + layer_outputs.append(output) + + # Calculate output lengths (assuming subsampling was applied earlier) + output_lens = src_key_padding_mask.size(1) - src_key_padding_mask.sum(dim=1) + + return layer_outputs, output_lens, attention_maps + + +class ConformerWithAttention(Conformer): + """Conformer with built-in attention map extraction support.""" + + def __init__(self, *args, attention_layers: Optional[List[int]] = None, **kwargs): + # Initialize parent without calling its __init__ to avoid double initialization + nn.Module.__init__(self) + + # Store parameters + self.num_features = kwargs.get('num_features') + self.subsampling_factor = kwargs.get('subsampling_factor', 4) + self.d_model = kwargs.get('d_model', 256) + self.attention_layers = attention_layers or [] + + # Create subsampling layer + from subsampling import Conv2dSubsampling + self.subsampling = Conv2dSubsampling( + in_channels=1, + out_channels=self.d_model, + subsampling_factor=self.subsampling_factor, + ) + + # Create encoder with attention support + self.encoder = ConformerEncoderWithAttention( + *args, + attention_layers=attention_layers, + **kwargs + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[List[torch.Tensor], torch.Tensor, Dict[int, torch.Tensor]]: + """Forward pass returning layer outputs and attention maps.""" + + # Subsampling + x, pos_emb = self.subsampling(x) + + # Create padding mask + max_len = x.size(1) + batch_size = x.size(0) + lengths_after_subsampling = ((x_lens - 1) // self.subsampling_factor) + 1 + + padding_mask = torch.arange(max_len, device=x.device)[None, :] >= lengths_after_subsampling[:, None] + + # Encoder forward pass + layer_outputs, output_lens, attention_maps = self.encoder( + x, pos_emb, src_key_padding_mask=padding_mask + ) + + return layer_outputs, output_lens, attention_maps diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc_sd/decode.py similarity index 79% rename from egs/swbd/ASR/conformer_ctc/decode.py rename to egs/librispeech/ASR/conformer_ctc_sd/decode.py index 52e501ae1..047b1fa21 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/decode.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) -# Modified by Zengrui Jin for the SwitchBoard corpus # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -27,9 +26,8 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import SwitchBoardAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer -from sclite_scoring import asr_text_post_processing from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint @@ -44,7 +42,6 @@ from icefall.decode import ( rescore_with_whole_lattice, ) from icefall.env import get_env_info -from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, @@ -65,7 +62,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=98, + default=77, help="It specifies the checkpoint to use for decoding." 
"Note: Epoch counts from 0.", ) @@ -132,15 +129,15 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="conformer_ctc/exp", + default="conformer_ctc/exp/models", help="The experiment dir", ) parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe_500", - help="The lang dir", + default="data/lang_bpe_5000", + help="The lang dir (using BPE)", ) parser.add_argument( @@ -219,9 +216,9 @@ def get_params() -> AttributeDict: "vgg_frontend": False, "use_feat_batchnorm": True, "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, + "nhead": 4, + "attention_dim": 256, + "num_decoder_layers": 0, # parameters for decoding "search_beam": 20, "output_beam": 8, @@ -234,17 +231,6 @@ def get_params() -> AttributeDict: return params -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -307,6 +293,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ + if HLG is not None: device = HLG.device else: @@ -317,10 +304,11 @@ def decode_one_batch( # at entry, feature is (N, T, C) supervisions = batch["supervisions"] - + # Step 1: Model forward pass nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) # nnet_output is (N, T, C) - + + # Step 2: Supervision segments preparation supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -330,6 +318,14 @@ def decode_one_batch( 1, ).to(torch.int32) + + # Ensure supervision segments don't exceed nnet_output length + max_allowed_frames = nnet_output.size(1) + supervision_segments[:, 2] = torch.clamp(supervision_segments[:, 2], max=max_allowed_frames) + + # CRITICAL FIX: k2.DenseFsaVec requires supervision_segments to be on CPU + supervision_segments = supervision_segments.cpu() + if H is None: assert HLG is not None decoding_graph = HLG @@ -337,7 +333,8 @@ def decode_one_batch( assert HLG is None assert bpe_model is not None decoding_graph = H - + + # Step 3: Lattice generation lattice = get_lattice( nnet_output=nnet_output, decoding_graph=decoding_graph, @@ -350,9 +347,11 @@ def decode_one_batch( ) if params.method == "ctc-decoding": + # Step 4: CTC decoding best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs # since we are using H, not HLG here. # @@ -364,6 +363,7 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] hyps = [s.split() for s in hyps] + key = "ctc-decoding" return {key: hyps} @@ -378,7 +378,7 @@ def decode_one_batch( ref_texts=supervisions["text"], word_table=word_table, nbest_scale=params.nbest_scale, - oov="", + oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -536,9 +536,17 @@ def decode_dataset( num_batches = "?" 
results = defaultdict(list) + + logging.info(f"Starting decode with {num_batches} batches") + for batch_idx, batch in enumerate(dl): + + logging.info(f"Processing batch {batch_idx}/{num_batches}") + texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + logging.info(f"Batch {batch_idx}: {len(texts)} cuts, cut_ids: {cut_ids[:3]}...") hyps_dict = decode_one_batch( params=params, @@ -549,11 +557,11 @@ def decode_dataset( bpe_model=bpe_model, batch=batch, word_table=word_table, - G=G, sos_id=sos_id, eos_id=eos_id, + G=G, ) - + if hyps_dict is not None: for lm_scale, hyps in hyps_dict.items(): this_batch = [] @@ -563,6 +571,35 @@ def decode_dataset( this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) + + # Log ground truth vs predicted examples for the first method only + if lm_scale == list(hyps_dict.keys())[0]: # Only log for the first decoding method + # Log a few examples from this batch + num_examples = min(3, len(texts)) # Show up to 3 examples per batch + if num_examples > 0: + logging.info(f"=== DECODE EXAMPLES - Batch {batch_idx} ===") + for i in range(num_examples): + cut_id = cut_ids[i] + ref_text = texts[i] + hyp_text = " ".join(hyps[i]) + + logging.info(f"Example {i+1} (ID: {cut_id}):") + logging.info(f" REF: {ref_text}") + logging.info(f" HYP: {hyp_text}") + + # Simple accuracy check + ref_words = ref_text.split() + hyp_words = hyps[i] + if ref_words == hyp_words: + logging.info(f" --> ✅ PERFECT MATCH ({len(ref_words)} words)") + else: + # Calculate simple word error rate for this utterance + import difflib + matcher = difflib.SequenceMatcher(None, ref_words, hyp_words) + word_errors = len(ref_words) + len(hyp_words) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_words) * 100) if len(ref_words) > 0 else 0 + logging.info(f" --> ❌ WER: {utt_wer:.1f}% (REF: {len(ref_words)} words, HYP: {len(hyp_words)} words)") + logging.info("=" * 50) else: assert len(results) > 0, "It should not decode to empty in the first batch!" 
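A note on the per-utterance check logged above: it counts errors from difflib matching blocks, which scores a substitution as two errors (one deletion plus one insertion), so it can read higher than the WER reported by write_error_stats. A minimal Levenshtein-based cross-check, offered only as a sketch (the helper name is ours, not part of the recipe):

def levenshtein_wer(ref_words, hyp_words):
    """Word error rate in percent, counting a substitution as one error."""
    # Dynamic-programming edit distance between the two word sequences.
    d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
    for i in range(len(ref_words) + 1):
        d[i][0] = i
    for j in range(len(hyp_words) + 1):
        d[0][j] = j
    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return 100.0 * d[-1][-1] / max(len(ref_words), 1)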
this_batch = [] @@ -576,10 +613,12 @@ def decode_dataset( num_cuts += len(texts) - if batch_idx % 100 == 0: + if batch_idx % 10 == 0: # Log more frequently for validation batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"[VALIDATION] batch {batch_str}, cuts processed: {num_cuts}, " + f"cuts in this batch: {len(texts)}") + + logging.info(f"Completed decode_dataset with {num_cuts} total cuts processed") return results @@ -593,63 +632,55 @@ def save_results( enable_log = False else: enable_log = True - if test_set_name == "test-eval2000": - subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} - elif test_set_name == "test-rt03": - subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} - else: - raise NotImplementedError(f"No implementation for testset {test_set_name}") - for subset, prefix in subsets.items(): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" - results = post_processing(results) - results = ( - sorted(list(filter(lambda x: x[0].startswith(prefix), results))) - if subset != "avg" - else sorted(results) + + # Create results directory if it doesn't exist + results_dir = params.exp_dir / "results" + results_dir.mkdir(exist_ok=True) + + test_set_wers = dict() + for key, results in results_dict.items(): + # Save transcripts in results folder + recog_path = results_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs - also save in results folder + errs_filename = results_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log ) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") + test_set_wers[key] = wer - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
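The recogs-*, errs-*, and wer-summary-* files that this save_results writes under exp_dir/results are plain text, so downstream scripts can read them back directly. A small sketch under that assumption (the helper name and example path are ours; the tab-separated "settings\tWER" layout matches what is written here):

from pathlib import Path

def load_wer_summary(path: Path) -> dict:
    """Parse a wer-summary-*.txt file: a "settings\tWER" header, then one row per setting."""
    wers = {}
    with open(path) as f:
        next(f)  # skip the header line
        for line in f:
            key, val = line.strip().split("\t")
            wers[key] = float(val)
    return wers

# e.g. load_wer_summary(Path("conformer_ctc/exp/models/results/wer-summary-test-clean.txt"))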
- errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{subset}-{key}", - results, - enable_log=enable_log, - sclite_mode=True, - ) - test_set_wers[key] = wer + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}-{}, WER of different settings are:\n".format( - test_set_name, subset - ) - note = "\tbest for {}".format(test_set_name) + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + # Save WER summary in results folder + errs_info = results_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + # Return WER results for external use + return dict(test_set_wers) @torch.no_grad() def main(): parser = get_parser() - SwitchBoardAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -662,9 +693,11 @@ def main(): logging.info("Decoding started") logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # For BPE mode: read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + num_classes = len(f.readlines()) + max_token_id = num_classes - 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -685,6 +718,16 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id + # Create BPE word table from tokens.txt + word_table = {} + with open(tokens_file, 'r', encoding='utf-8') as f: + for line in f: + if line.strip(): + parts = line.strip().split() + if len(parts) >= 2: + token, idx = parts[0], parts[1] + word_table[int(idx)] = token + if params.method == "ctc-decoding": HLG = None H = k2.ctc_topo( @@ -715,7 +758,8 @@ def main(): logging.info("Loading G_4_gram.fst.txt") logging.warning("It may take 8 minutes.") with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] + # For BPE mode: use a default disambig ID (assuming #0 maps to ID 0) + first_word_disambig_id = 0 # This should be adjusted based on your BPE vocab G = k2.Fsa.from_openfst(f.read(), acceptor=False) # G.aux_labels is not needed in later computations, so @@ -808,22 +852,16 @@ def main(): # we need cut ids to display recognition results. 
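For reference, the tokens.txt handling in main() above assumes the usual icefall layout of one "symbol id" pair per line, so the vocabulary size equals the line count and the largest token ID is num_classes - 1. A minimal sketch of the same mapping (the function name is ours):

def read_token_table(tokens_file: str) -> dict:
    """Map token ID -> token string from a tokens.txt with "symbol id" lines."""
    id2token = {}
    with open(tokens_file, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                id2token[int(parts[1])] = parts[0]
    return id2token

# num_classes can then be taken as len(id2token), matching the logic above.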
args.return_cuts = True - switchboard = SwitchBoardAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) - test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( - keep_all_channels=True - ) - # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( - # keep_all_channels=True - # ) + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() - test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) - # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) - # test_sets = ["test-eval2000", "test-rt03"] - # test_dl = [test_eval2000_dl, test_rt03_dl] - test_sets = ["test-eval2000"] - test_dl = [test_eval2000_dl] + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( @@ -834,7 +872,7 @@ def main(): HLG=HLG, H=H, bpe_model=bpe_model, - word_table=lexicon.word_table, + word_table=word_table, G=G, sos_id=sos_id, eos_id=eos_id, diff --git a/egs/librispeech/ASR/conformer_ctc_sd/detailed_analysis.py b/egs/librispeech/ASR/conformer_ctc_sd/detailed_analysis.py new file mode 100644 index 000000000..78b91d1a9 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/detailed_analysis.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +""" +Detailed test to understand attention map size changes and encoder output structure +""" + +import torch +import torch.nn as nn +from conformer_ctc import ConformerCTC + +def test_attention_map_sizes(): + """Test attention map sizes at different layers""" + print("=" * 60) + print("ATTENTION MAP SIZE ANALYSIS") + print("=" * 60) + + # Create model + model = ConformerCTC( + num_features=80, + num_classes=500, + d_model=256, + num_encoder_layers=6, + nhead=4, + distill_layers=[0, 1, 2, 3, 4, 5], # All layers + knowledge_type='attention-map', + ) + model.eval() + + # Create test input with different sequence lengths + batch_size = 2 + seq_lens = [100, 80] # Different lengths to see padding effect + max_len = max(seq_lens) + + # Create input + x = torch.randn(batch_size, max_len, 80) + x_lens = torch.tensor(seq_lens) + + # Create targets + y = torch.randint(0, 500, (batch_size, 50)) + y_lens = torch.tensor([50, 40]) + + print(f"Input shape: {x.shape}") + print(f"Input lengths: {x_lens}") + print(f"Target shape: {y.shape}") + print(f"Target lengths: {y_lens}") + print() + + # Forward pass + with torch.no_grad(): + # ConformerCTC forward only takes x and supervisions + supervisions = { + 'sequence_idx': torch.arange(batch_size), + 'start_frame': torch.zeros(batch_size), + 'num_frames': x_lens, + } + outputs = model(x, supervisions) + + print("Attention maps from different layers:") + for i, attn_map in enumerate(outputs['distill_outputs']): + print(f"Layer {model.distill_layers[i]}: {attn_map.shape}") + # Check for NaN or inf + if torch.isnan(attn_map).any(): + print(f" ⚠️ WARNING: NaN detected in layer {i}") + if torch.isinf(attn_map).any(): + print(f" ⚠️ WARNING: Inf detected in layer {i}") + + print() + print("Analysis:") + print("- All attention maps should have same batch_size and num_heads") + print("- Sequence length dimension may vary due to subsampling in conformer") + print("- Later layers typically have shorter sequences due to subsampling") + +def test_encoder_output_structure(): + """Test encoder output structure and understand 
distill_outputs vs distill_hidden""" + print("=" * 60) + print("ENCODER OUTPUT STRUCTURE ANALYSIS") + print("=" * 60) + + # Create model + model = ConformerCTC( + num_features=80, + num_classes=500, + d_model=256, + num_encoder_layers=6, + nhead=4, + distill_layers=[1, 3, 5], # Selected layers + knowledge_type='encoder-output', + ) + model.eval() + + # Create test input + batch_size = 2 + seq_lens = [100, 80] + max_len = max(seq_lens) + + x = torch.randn(batch_size, max_len, 80) + x_lens = torch.tensor(seq_lens) + y = torch.randint(0, 500, (batch_size, 50)) + y_lens = torch.tensor([50, 40]) + + print(f"Input shape: {x.shape}") + print(f"Selected distillation layers: {model.distill_layers}") + print() + + # Forward pass + with torch.no_grad(): + # ConformerCTC forward only takes x and supervisions + supervisions = { + 'sequence_idx': torch.arange(batch_size), + 'start_frame': torch.zeros(batch_size), + 'num_frames': x_lens, + } + outputs = model(x, supervisions) + + print("Encoder outputs from selected layers:") + print(f"distill_outputs type: {type(outputs['distill_outputs'])}") + print(f"distill_outputs length: {len(outputs['distill_outputs'])}") + + for i, enc_output in enumerate(outputs['distill_outputs']): + layer_idx = model.distill_layers[i] + print(f"Layer {layer_idx}: {enc_output.shape}") + print(f" Mean: {enc_output.mean().item():.4f}") + print(f" Std: {enc_output.std().item():.4f}") + + print() + if 'distill_hidden' in outputs: + print("distill_hidden structure:") + print(f"distill_hidden type: {type(outputs['distill_hidden'])}") + if isinstance(outputs['distill_hidden'], (list, tuple)): + print(f"distill_hidden length: {len(outputs['distill_hidden'])}") + for i, hidden in enumerate(outputs['distill_hidden']): + print(f"Hidden {i}: {hidden.shape if torch.is_tensor(hidden) else type(hidden)}") + else: + print(f"distill_hidden shape: {outputs['distill_hidden'].shape}") + else: + print("❌ distill_hidden not found in outputs") + + print() + print("Key differences:") + print("- distill_outputs: Contains the actual hidden states from selected encoder layers") + print("- distill_hidden: May contain additional context or processed versions") + print("- For encoder-output distillation, we primarily use distill_outputs") + +def test_subsampling_effect(): + """Test how subsampling affects sequence lengths through layers""" + print("=" * 60) + print("SUBSAMPLING EFFECT ANALYSIS") + print("=" * 60) + + from conformer import ConformerEncoder + + # Create standalone encoder to track subsampling + encoder = ConformerEncoder( + num_features=80, + d_model=256, + num_layers=6, + nhead=4, + distill_layers=[0, 1, 2, 3, 4, 5], + knowledge_type='attention-map' + ) + encoder.eval() + + # Test input + batch_size = 2 + seq_len = 100 + x = torch.randn(batch_size, seq_len, 80) + x_lens = torch.tensor([seq_len, seq_len]) + + print(f"Original input: {x.shape}") + + with torch.no_grad(): + encoder_out, encoder_out_lens, distill_outputs, distill_hidden = encoder(x, x_lens) + + print(f"Final encoder output: {encoder_out.shape}") + print(f"Final encoder lengths: {encoder_out_lens}") + print() + + if encoder.knowledge_type == 'attention-map': + print("Attention map sizes through layers:") + for i, attn_map in enumerate(distill_outputs): + # Attention map shape: [batch, num_heads, seq_len, seq_len] + seq_len_at_layer = attn_map.shape[-1] + print(f"Layer {encoder.distill_layers[i]}: attention [{attn_map.shape[0]}, {attn_map.shape[1]}, {attn_map.shape[2]}, {attn_map.shape[3]}] -> seq_len = {seq_len_at_layer}") + + 
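The length reduction observed above comes from the convolutional subsampling in front of the encoder. Assuming the same two stride-2 convolutions as the pretrained-decoding scripts in this directory (which compute ((len - 1) // 2 - 1) // 2), a quick helper to predict the attention-map size is:

def subsampled_length(num_frames: int) -> int:
    """Frame count after two stride-2 convolutions (Conv2dSubsampling)."""
    return ((num_frames - 1) // 2 - 1) // 2

# For a 100-frame input this gives 24, so each attention map would be roughly
# [batch_size, num_heads, 24, 24] after subsampling.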
print() + print("Observations:") + print("- Sequence length typically reduces due to subsampling in early layers") + print("- This affects attention map sizes (they become smaller)") + print("- All layers after subsampling will have the same reduced sequence length") + +if __name__ == "__main__": + print("🔍 DETAILED DISTILLATION ANALYSIS") + print("=" * 80) + + try: + test_attention_map_sizes() + print("\n" + "="*80 + "\n") + + test_encoder_output_structure() + print("\n" + "="*80 + "\n") + + test_subsampling_effect() + + print("\n✅ All tests completed successfully!") + + except Exception as e: + print(f"❌ Error during testing: {e}") + import traceback + traceback.print_exc() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/ema_teacher.py similarity index 100% rename from egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/ema_teacher.py diff --git a/egs/aishell/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/conformer_ctc_sd/encoder_interface.py old mode 100644 new mode 100755 similarity index 100% rename from egs/aishell/ASR/transducer_stateless/encoder_interface.py rename to egs/librispeech/ASR/conformer_ctc_sd/encoder_interface.py diff --git a/egs/librispeech/ASR/conformer_ctc_sd/explain_distillation.py b/egs/librispeech/ASR/conformer_ctc_sd/explain_distillation.py new file mode 100644 index 000000000..cee4b7893 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/explain_distillation.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +""" +Simple explanation of attention map sizes and encoder outputs +""" + +import torch +from conformer_ctc import ConformerCTC + +def explain_attention_map_sizes(): + """Explain how attention map sizes are determined""" + print("🔍 ATTENTION MAP SIZE EXPLANATION") + print("=" * 50) + + print("Attention map size is determined as follows:") + print() + print("1. Input sequence length:") + print(" - Determined by the number of original audio frames") + print(" - e.g., 100 frames -> length 100") + print() + print("2. Subsampling effect:") + print(" - The Conformer performs subsampling in its early layers") + print(" - Usually 4x or 6x compression") + print(" - e.g., 100 -> 25 (4x), 100 -> 16 (6x)") + print() + print("3. 
Attention map shape:") + print(" - [batch_size, num_heads, seq_len, seq_len]") + print(" - seq_len is the sequence length at that layer") + print(" - After subsampling, all layers share the same seq_len") + print() + + # Create test model + model = ConformerCTC( + num_features=80, + num_classes=500, + d_model=256, + num_encoder_layers=6, + nhead=4, + distill_layers=[2, 4], + knowledge_type='attention-map', + ) + + # Test with different input sizes + test_cases = [ + {"seq_len": 50, "name": "Short audio"}, + {"seq_len": 100, "name": "Medium audio"}, + {"seq_len": 200, "name": "Long audio"} + ] + + print("Actual tests:") + for case in test_cases: + seq_len = case["seq_len"] + batch_size = 2 + + x = torch.randn(batch_size, seq_len, 80) + supervisions = { + 'sequence_idx': torch.arange(batch_size), + 'start_frame': torch.zeros(batch_size), + 'num_frames': torch.tensor([seq_len, seq_len]), + } + + with torch.no_grad(): + outputs = model(x, supervisions) + + print(f"\n{case['name']} (input length: {seq_len}):") + if 'distill_outputs' in outputs and len(outputs['distill_outputs']) > 0: + for i, attn_map in enumerate(outputs['distill_outputs']): + layer_idx = model.distill_layers[i] + attn_seq_len = attn_map.shape[-1] + compression_ratio = seq_len / attn_seq_len + print(f" Layer {layer_idx}: {attn_map.shape} (compression ratio: {compression_ratio:.1f}x)") + +def explain_encoder_outputs(): + """Explain encoder output structure""" + print("\n🔍 ENCODER OUTPUT vs DISTILL OUTPUT") + print("=" * 50) + + print("In encoder-output mode:") + print() + print("1. distill_outputs:") + print(" - The actual hidden states of the selected layers") + print(" - Each layer's encoder output (feature representations)") + print(" - Shape: [batch_size, seq_len, d_model]") + print(" - The information used directly for self-distillation") + print() + print("2. distill_hidden:") + print(" - Identical to distill_outputs in the current implementation") + print(" - A placeholder for future extensions") + print(" - May hold additional context or processed information") + print() + + # Test encoder output mode + model = ConformerCTC( + num_features=80, + num_classes=500, + d_model=256, + num_encoder_layers=6, + nhead=4, + distill_layers=[1, 3, 5], + knowledge_type='encoder-output', + ) + + batch_size = 2 + seq_len = 100 + x = torch.randn(batch_size, seq_len, 80) + supervisions = { + 'sequence_idx': torch.arange(batch_size), + 'start_frame': torch.zeros(batch_size), + 'num_frames': torch.tensor([seq_len, seq_len]), + } + + with torch.no_grad(): + outputs = model(x, supervisions) + + print("Concrete example:") + print(f"Input size: {x.shape}") + print(f"Selected layers: {model.distill_layers}") + print() + + if 'distill_outputs' in outputs: + print("distill_outputs (hidden states of each layer):") + for i, enc_out in enumerate(outputs['distill_outputs']): + layer_idx = model.distill_layers[i] + print(f" Layer {layer_idx}: {enc_out.shape}") + + print() + print("📝 Summary:") + print("- Attention map: the sequence length is determined by subsampling") + print("- Encoder output: each layer's feature representation") + print("- distill_outputs is the core data for self-distillation") + +if __name__ == "__main__": + try: + explain_attention_map_sizes() + explain_encoder_outputs() + print("\n✅ Explanation complete!") + + except Exception as e: + print(f"❌ Error: {e}") + import traceback + traceback.print_exc() diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc_sd/export.py similarity index 97% rename from egs/aishell/ASR/conformer_ctc/export.py rename to egs/librispeech/ASR/conformer_ctc_sd/export.py index 49871d437..f0bb97560 100755 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/export.py @@ -28,6 +28,7 @@ import torch from conformer import 
Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, num_tokens, str2bool @@ -39,7 +40,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=84, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -47,7 +48,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=25, + default=20, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -87,7 +88,7 @@ def get_params() -> AttributeDict: "subsampling_factor": 4, "use_feat_batchnorm": True, "attention_dim": 512, - "nhead": 4, + "nhead": 8, "num_decoder_layers": 6, } ) @@ -101,13 +102,13 @@ def main(): params = get_params() params.update(vars(args)) + logging.info(params) + # Load tokens.txt here token_table = k2.SymbolTable.from_file(params.tokens) num_classes = num_tokens(token_table) + 1 # +1 for the blank - logging.info(params) - device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_H.py similarity index 67% rename from egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py rename to egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_H.py index 72127aebd..e9acf7e0b 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_H.py @@ -7,15 +7,27 @@ on CPU using OpenFST and decoders from kaldi. Usage: - ./tdnn/jit_pretrained_decode_with_H.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --H ./data/lang_phone/H.fst \ - --tokens ./data/lang_phone/tokens.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav +(1) LibriSpeech conformer_ctc -Note that to generate ./tdnn/exp/cpu_jit.pt, + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_bpe_500/H.fst \ + --tokens ./data/lang_bpe_500/tokens.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_char/H.fst \ + --tokens ./data/lang_char/tokens.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -42,7 +54,7 @@ def get_parser(): type=str, required=True, help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 + You can use ./conformer_ctc/export.py --jit 1 to obtain it """, ) @@ -111,13 +123,29 @@ def decode( H: kaldifst, id2token: Dict[int, str], ) -> List[str]: - decodable = DecodableCtc(nnet_output) + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2token: + A map mapping token ID to token string. + Returns: + Return a list of decoded tokens. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + decoder_opts = FasterDecoderOptions(max_active=3000) decoder = FasterDecoder(H, decoder_opts) decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -129,11 +157,14 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] - # are shifted by 1 during graph construction - hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] + # tokens are incremented during graph construction + # so they need to be decremented + hyps = [id2token[i - 1] for i in osymbols_out] + # hyps = "".join(hyps).split("▁") + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ return hyps @@ -155,7 +186,7 @@ def main(): logging.info(f"Loading H from {args.H}") H = kaldifst.StdVectorFst.read(args.H) - sample_rate = 8000 + sample_rate = 16000 logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() @@ -163,7 +194,7 @@ def main(): opts.frame_opts.dither = 0 opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 23 + opts.mel_opts.num_bins = 80 opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) @@ -176,18 +207,26 @@ def main(): logging.info("Decoding started") features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - nnet_output = model(features) + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 id2token = read_tokens(args.tokens) hyps = [] for i in range(nnet_output.shape[0]): hyp = decode( - filename=args.sound_files[0], - nnet_output=nnet_output[i], + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], H=H, id2token=id2token, ) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HL.py similarity index 68% rename from egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py rename to egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HL.py index f8a057336..5753aa5d3 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HL.py @@ -7,15 +7,27 @@ on CPU using OpenFST and decoders from kaldi. 
Usage: - ./tdnn/jit_pretrained_decode_with_HL.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HL ./data/lang_phone/HL.fst \ - --words ./data/lang_phone/words.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav +(1) LibriSpeech conformer_ctc -Note that to generate ./tdnn/exp/cpu_jit.pt, + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_bpe_500/HL.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_char/HL.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -42,7 +54,7 @@ def get_parser(): type=str, required=True, help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 + You can use ./conformer_ctc/export.py --jit 1 to obtain it """, ) @@ -111,13 +123,29 @@ def decode( HL: kaldifst, id2word: Dict[int, str], ) -> List[str]: - decodable = DecodableCtc(nnet_output) + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + decoder_opts = FasterDecoderOptions(max_active=3000) decoder = FasterDecoder(HL, decoder_opts) decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -129,10 +157,11 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] - hyps = [id2word[i] for i in osymbols_out if id2word[i] != ""] + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] return hyps @@ -154,7 +183,7 @@ def main(): logging.info(f"Loading HL from {args.HL}") HL = kaldifst.StdVectorFst.read(args.HL) - sample_rate = 8000 + sample_rate = 16000 logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() @@ -162,7 +191,7 @@ def main(): opts.frame_opts.dither = 0 opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 23 + opts.mel_opts.num_bins = 80 opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) @@ -175,18 +204,26 @@ def main(): logging.info("Decoding started") features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - nnet_output = model(features) + nnet_output, _, _ = 
model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 id2word = read_words(args.words) hyps = [] for i in range(nnet_output.shape[0]): hyp = decode( - filename=args.sound_files[0], - nnet_output=nnet_output[i], + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], HL=HL, id2word=id2word, ) diff --git a/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HLG.py new file mode 100755 index 000000000..b6e3333ce --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/jit_pretrained_decode_with_HLG.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HLG +on CPU using OpenFST and decoders from kaldi. + +Usage: + +(1) LibriSpeech conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_bpe_500/HLG.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_char/HLG.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HLG=HLG, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git 
a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc_sd/label_smoothing.py old mode 100644 new mode 100755 similarity index 81% rename from egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py rename to egs/librispeech/ASR/conformer_ctc_sd/label_smoothing.py index 3b94f0c4b..52d2eda3b --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/label_smoothing.py @@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module): mean of the output is taken. (3) "sum": the output will be summed. """ super().__init__() - assert 0.0 <= label_smoothing < 1.0 + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction self.ignore_index = ignore_index self.label_smoothing = label_smoothing self.reduction = reduction @@ -76,15 +77,28 @@ class LabelSmoothingLoss(torch.nn.Module): target = target.clone().reshape(-1) ignored = target == self.ignore_index - target[ignored] = 0 + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) + # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) if self.reduction == "sum": diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc_sd/pretrained.py similarity index 92% rename from egs/mgb2/ASR/conformer_ctc/pretrained.py rename to egs/librispeech/ASR/conformer_ctc_sd/pretrained.py index 0ab2af527..38b60fcb9 100755 --- a/egs/mgb2/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc_sd/pretrained.py @@ -24,7 +24,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -70,11 +69,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -83,10 +80,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -236,9 +232,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -298,6 +294,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) @@ -314,13 +311,20 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, - modified=False, + modified=params.num_classes > 500, device=device, ) @@ -338,9 +342,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "whole-lattice-rescoring", @@ -409,16 +413,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/aishell/ASR/conformer_ctc/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/quick_test_distillation.sh similarity index 100% rename from egs/aishell/ASR/conformer_ctc/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/quick_test_distillation.sh diff --git a/egs/librispeech/ASR/conformer_ctc_sd/scaling.py b/egs/librispeech/ASR/conformer_ctc_sd/scaling.py new file mode 100755 index 000000000..91d64c1df --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/scaling.py @@ -0,0 +1,1014 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
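Returning to the ctc-decoding branch of pretrained.py above: mapping token IDs back to words works by concatenating BPE pieces and turning the "▁" word-boundary marker back into spaces. A tiny worked example (the table entries here are invented purely for illustration):

token_table = {0: "<blk>", 1: "▁HEL", 2: "LO", 3: "▁WORLD"}

def token_ids_to_words(token_ids):
    # Join the BPE pieces, then restore word boundaries at "▁".
    text = "".join(token_table[i] for i in token_ids)
    return text.replace("▁", " ").strip()

assert token_ids_to_words([1, 2, 3]) == "HELLO WORLD"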
+ + +import collections +import random +from itertools import repeat +from typing import Optional, Tuple + +import torch +import torch.backends.cudnn.rnn as rnn +import torch.nn as nn +from torch import _VF, Tensor + +from icefall.utils import is_jit_tracing + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + ) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + + # sum_dims = [d for d in range(x.ndim) if d != channel_dim] + # The above line is not torch scriptable for torch 1.6.0 + # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa + sum_dims = [] + for d in range(x.ndim): + if d != channel_dim: + sum_dims.append(d) + + xgt0 = x > 0 + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) + + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs + + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.max_factor = max_factor + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors + dtype = x_grad.dtype + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) + + neg_delta_grad = x_grad.abs() * (factor + scale_factor) + return x_grad - neg_delta_grad, None, None, None, None, None, None + + +class GradientFilterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + batch_dim: int, # e.g., 1 + threshold: float, # e.g., 10.0 + *params: Tensor, # module parameters + ) -> Tuple[Tensor, ...]: + if x.requires_grad: + if batch_dim < 0: + batch_dim += x.ndim + ctx.batch_dim = batch_dim + ctx.threshold = threshold + return (x,) + params + + @staticmethod + def backward( + ctx, + x_grad: Tensor, + *param_grads: Tensor, + ) -> Tuple[Tensor, ...]: + eps = 1.0e-20 + dim = ctx.batch_dim + norm_dims = [d for d in range(x_grad.ndim) if d != dim] + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + median_norm = norm_of_batch.median() + + cutoff = median_norm * ctx.threshold + inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) + mask = 1.0 / (inv_mask + eps) + x_grad = x_grad * mask + + avg_mask = 1.0 / (inv_mask.mean() + eps) + param_grads = [avg_mask * g for g in param_grads] + + return (x_grad, None, None) + tuple(param_grads) + + +class GradientFilter(torch.nn.Module): + """This is used to filter out elements that have extremely large gradients + in batch and the module parameters with soft masks. 
+ + Args: + batch_dim (int): + The batch dimension. + threshold (float): + For each element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch. + """ + + def __init__(self, batch_dim: int = 1, threshold: float = 10.0): + super(GradientFilter, self).__init__() + self.batch_dim = batch_dim + self.threshold = threshold + + def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: + if torch.jit.is_scripting() or is_jit_tracing(): + return (x,) + params + else: + return GradientFilterFunction.apply( + x, + self.batch_dim, + self.threshold, + *params, + ) + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + + def forward(self, x: Tensor) -> Tensor: + if not is_jit_tracing(): + assert x.shape[self.channel_dim] == self.num_channels + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 + return x * scales + + +class ScaledLinear(nn.Linear): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + initial_speed: this affects how fast the parameter will + learn near the start of training; you can set it to a + value less than one if you suspect that a module + is contributing to instability near the start of training. 
+ Nnote: regardless of the use of this option, it's best to + use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. + """ + + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs, + ): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3**0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in**-0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + if self.bias is None or self.bias_scale is None: + return None + else: + return self.bias * self.bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs, + ): + super(ScaledConv1d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + + self.bias_scale: Optional[nn.Parameter] # for torchscript + + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3**0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in**-0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + bias = self.bias + bias_scale = self.bias_scale + if bias is None or bias_scale is None: + return None + else: + return bias * bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + (0,), + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class ScaledConv2d(nn.Conv2d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs, + ): + super(ScaledConv2d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = 
nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3**0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in**-0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + # see https://github.com/pytorch/pytorch/issues/24135 + bias = self.bias + bias_scale = self.bias_scale + if bias is None or bias_scale is None: + return None + else: + return bias * bias_scale.exp() + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + (0, 0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) + + +class ScaledLSTM(nn.LSTM): + # See docs for ScaledLinear. + # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` + # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + grad_norm_threshold: float = 10.0, + **kwargs, + ): + if "bidirectional" in kwargs: + assert kwargs["bidirectional"] is False + super(ScaledLSTM, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self._scales_names = [] + self._scales = [] + for name in self._flat_weights_names: + scale_name = name + "_scale" + self._scales_names.append(scale_name) + param = nn.Parameter(initial_scale.clone().detach()) + setattr(self, scale_name, param) + self._scales.append(param) + + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3**0.5) * std + scale = self.hidden_size**-0.5 + v = scale / std + for idx, name in enumerate(self._flat_weights_names): + if "weight" in name: + nn.init.uniform_(self._flat_weights[idx], -a, a) + with torch.no_grad(): + self._scales[idx] += torch.tensor(v).log() + elif "bias" in name: + nn.init.constant_(self._flat_weights[idx], 0.0) + + def _flatten_parameters(self, flat_weights) -> None: + """Resets parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. 
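All of the Scaled* layers above (ScaledLinear, ScaledConv1d/2d, ScaledLSTM) share one reparameterisation: the effective weight is weight * exp(weight_scale), and _reset_parameters moves the usual 1/sqrt(fan_in) magnitude into the log-scale so the raw parameter keeps a larger uniform std of 0.1/initial_speed. A standalone sketch of that arithmetic, using illustrative shapes:

    import math
    import torch

    fan_in, fan_out = 256, 128
    initial_speed = 1.0
    std = 0.1 / initial_speed
    a = math.sqrt(3) * std                               # uniform(-a, a) has std == std
    weight = torch.empty(fan_out, fan_in).uniform_(-a, a)
    weight_scale = torch.tensor(1.0).log()               # initial_scale = 1.0 -> log(1) = 0
    weight_scale += torch.tensor((fan_in ** -0.5) / std).log()
    effective = weight * weight_scale.exp()
    print(weight.std().item())                           # ~0.1, well-conditioned for the optimizer
    print(effective.std().item(), fan_in ** -0.5)        # both ~0.0625, i.e. Glorot-like

The point of the split is that the optimizer sees a parameter of convenient magnitude while the network still computes with a conventionally scaled weight.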
+ + This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa + """ + # Short-circuits if _flat_weights is only partially instantiated + if len(flat_weights) != len(self._flat_weights_names): + return + + for w in flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN + # or the tensors in flat_weights are of different dtypes + + first_fw = flat_weights[0] + dtype = first_fw.dtype + for fw in flat_weights: + if ( + not isinstance(fw.data, Tensor) + or not (fw.data.dtype == dtype) + or not fw.data.is_cuda + or not torch.backends.cudnn.is_acceptable(fw.data) + ): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = set(p.data_ptr() for p in flat_weights) + if len(unique_data_ptrs) != len(flat_weights): + return + + with torch.cuda.device_of(first_fw): + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + flat_weights, + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) + + def _get_flat_weights(self): + """Get scaled weights, and resets their data pointer.""" + flat_weights = [] + for idx in range(len(self._flat_weights_names)): + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + self._flatten_parameters(flat_weights) + return flat_weights + + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa + # The change for calling `_VF.lstm()` is: + # self._flat_weights -> self._get_flat_weights() + if hx is None: + h_zeros = torch.zeros( + self.num_layers, + input.size(1), + self.proj_size if self.proj_size > 0 else self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers, + input.size(1), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + + self.check_forward_args(input, hx, None) + + flat_weights = self._get_flat_weights() + input, *flat_weights = self.grad_filter(input, *flat_weights) + + result = _VF.lstm( + input, + hx, + flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + + output = result[0] + hidden = result[1:] + return output, hidden + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + + Args: + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. 
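A hedged usage sketch of ScaledLSTM as defined above, mirroring _test_scaled_lstm further down in this file; it assumes this scaling.py is importable from the recipe directory. Inputs are (T, N, C) because batch_first defaults to False, and the GradientFilter sits between the input and the _VF.lstm call:

    import torch
    from scaling import ScaledLSTM   # the module added in this diff

    lstm = ScaledLSTM(
        input_size=80,
        hidden_size=512,
        bias=True,
        grad_norm_threshold=10.0,     # forwarded to the internal GradientFilter
    )
    x = torch.randn(100, 8, 80)       # (T, N, C)
    y, (h, c) = lstm(x)               # hx defaults to zeros inside forward()
    print(y.shape, h.shape, c.shape)  # (100, 8, 512), (1, 8, 512), (1, 8, 512)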
+ min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + balance_prob: the probability to apply the ActivationBalancer. + """ + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + balance_prob: float = 0.25, + ): + super(ActivationBalancer, self).__init__() + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + assert 0 < balance_prob <= 1, balance_prob + self.balance_prob = balance_prob + + def forward(self, x: Tensor) -> Tensor: + if random.random() >= self.balance_prob: + return x + + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor / self.balance_prob, + self.min_abs, + self.max_abs, + ) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + s, y = ctx.saved_tensors + return (y * (1 - s) + s) * y_grad + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or is_jit_tracing(): + return x * torch.sigmoid(x - 1.0) + else: + return DoubleSwishFunction.apply(x) + + +class ScaledEmbedding(nn.Module): + r"""This is a modified version of nn.Embedding that introduces a learnable scale + on the parameters. Note: due to how we initialize it, it's best used with + schedulers like Noam that have a warmup period. + + It is a simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. 
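The derivative identity that DoubleSwishFunction's backward relies on can be checked directly against autograd; a minimal sketch:

    import torch

    # double_swish(x) = x * sigmoid(x - 1); its derivative is y * (1 - s) + s,
    # where s = sigmoid(x - 1) and y = double_swish(x), so only s and y need saving.
    x = torch.randn(1000, dtype=torch.double, requires_grad=True)
    s = torch.sigmoid(x - 1.0)
    y = x * s
    y.sum().backward()
    manual_grad = (y * (1 - s) + s).detach()
    print(torch.allclose(x.grad, manual_grad))   # True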
+ + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + initial_speed (float, optional): This affects how fast the parameter will + learn near the start of training; you can set it to a value less than + one if you suspect that a module is contributing to instability near + the start of training. Note: regardless of the use of this option, + it's best to use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. + + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. 
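As the warm-up note above suggests, ScaledEmbedding keeps a small raw weight plus a separate log-scale; reset_parameters (defined just below) initialises weight with std = 0.1/initial_speed and sets scale = log(1/std), so the effective embedding weight * exp(scale) starts near unit variance. A standalone sketch of that initialisation:

    import torch

    initial_speed = 1.0
    std = 0.1 / initial_speed
    weight = torch.randn(5000, 512) * std          # what nn.init.normal_(std=std) produces
    scale = torch.tensor(1.0 / std).log()          # constant init used in reset_parameters
    effective = weight * scale.exp()
    print(weight.std().item())                     # ~0.1
    print(effective.std().item())                  # ~1.0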
+ + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + """ + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + initial_speed: float = 1.0, + ) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters(initial_speed) + + def reset_parameters(self, initial_speed: float = 1.0) -> None: + std = 0.1 / initial_speed + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) + else: + return F.embedding( + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + # s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + 
y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) + + +def _test_scaled_lstm(): + N, L = 2, 30 + dim_in, dim_hidden = 10, 20 + m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True) + x = torch.randn(L, N, dim_in) + h0 = torch.randn(1, N, dim_hidden) + c0 = torch.randn(1, N, dim_hidden) + y, (h, c) = m(x, (h0, c0)) + assert y.shape == (L, N, dim_hidden) + assert h.shape == (1, N, dim_hidden) + assert c.shape == (1, N, dim_hidden) + + +def _test_grad_filter(): + threshold = 50.0 + time, batch, channel = 200, 5, 128 + grad_filter = GradientFilter(batch_dim=1, threshold=threshold) + + for i in range(2): + x = torch.randn(time, batch, channel, requires_grad=True) + w = nn.Parameter(torch.ones(5)) + b = nn.Parameter(torch.zeros(5)) + + x_out, w_out, b_out = grad_filter(x, w, b) + + w_out_grad = torch.randn_like(w) + b_out_grad = torch.randn_like(b) + x_out_grad = torch.rand_like(x) + if i % 2 == 1: + # The gradient norm of the first element must be larger than + # `threshold * median`, where `median` is the median value + # of gradient norms of all elements in batch. 
+ x_out_grad[:, 0, :] = torch.full((time, channel), threshold) + + torch.autograd.backward( + [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad] + ) + + print( + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa + i % 2 == 1, + ) + + print( + "_test_grad_filter: x_out_grad norm = ", + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), + ) + print( + "_test_grad_filter: x.grad norm = ", + (x.grad**2).mean(dim=(0, 2)).sqrt(), + ) + print("_test_grad_filter: w_out_grad = ", w_out_grad) + print("_test_grad_filter: w.grad = ", w.grad) + print("_test_grad_filter: b_out_grad = ", b_out_grad) + print("_test_grad_filter: b.grad = ", b.grad) + + +if __name__ == "__main__": + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() + _test_scaled_lstm() + _test_grad_filter() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc_sd/subsampling.py old mode 100644 new mode 100755 similarity index 100% rename from egs/aishell/ASR/conformer_ctc/subsampling.py rename to egs/librispeech/ASR/conformer_ctc_sd/subsampling.py diff --git a/egs/aishell/ASR/conformer_mmi/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/test_attention_distillation.py similarity index 100% rename from egs/aishell/ASR/conformer_mmi/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/test_attention_distillation.py diff --git a/egs/aishell/ASR/local/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/test_clean_noisy.py similarity index 100% rename from egs/aishell/ASR/local/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/test_clean_noisy.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/test_ema_teacher.py similarity index 100% rename from egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/test_ema_teacher.py diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/__init__.py b/egs/librispeech/ASR/conformer_ctc_sd/test_self_distillation.py similarity index 100% rename from egs/aishell/ASR/tdnn_lstm_ctc/__init__.py rename to egs/librispeech/ASR/conformer_ctc_sd/test_self_distillation.py diff --git a/egs/librispeech/ASR/conformer_ctc_sd/train.py b/egs/librispeech/ASR/conformer_ctc_sd/train.py new file mode 100755 index 000000000..4e0b33bd5 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc_sd/train.py @@ -0,0 +1,1674 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
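_test_grad_filter above exercises the rule stated in GradientFilter's docstring: a batch element is dropped from the backward pass when its gradient norm exceeds threshold * median of the per-element norms. The sketch below reproduces only that masking rule on a raw gradient tensor, using the RMS-style norm the test prints; the actual filtering is done by GradientFilterFunction.backward, defined earlier in this file, and its exact norm should be taken from there:

    import torch

    threshold = 10.0
    grad = torch.randn(200, 5, 128)                       # (T, N, C), batch_dim = 1
    grad[:, 0, :] = 5 * threshold                         # make element 0 a clear outlier
    norms = (grad ** 2).mean(dim=(0, 2)).sqrt()           # one norm per batch element
    median = norms.median()
    keep = (norms <= threshold * median).to(grad.dtype)   # 0.0 for filtered elements
    filtered_grad = grad * keep.view(1, -1, 1)
    print(keep)                                           # element 0 is zeroed out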
+ +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --full-libri 1 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import sentencepiece as spm +from asr_datamodule import LibriSpeechAsrDataModule +from conformer_ctc import ConformerCTC, compute_distillation_loss +from ema_teacher import EMATeacher +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam +from decode import decode_dataset, save_results + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + load_averaged_model, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +# Global counter for validation samples to control terminal logging frequency +_VALIDATION_SAMPLE_COUNTER = 0 + + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="./conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="./data/lang_phone", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--bpe-dir", + type=str, + default="./data/lang_bpe_5000", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. 
+ Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--warm-step", + type=int, + default=30000, + help="Number of warmup steps for Noam optimizer. " + "Recommended: 30000 (with data aug), 15000-20000 (without data aug)", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + parser.add_argument( + "--sanity-check", + type=str2bool, + default=True, + help="About Sanity check process", + ) + + # Self-distillation arguments + parser.add_argument( + "--enable-self-distillation", + type=str2bool, + default=True, + help="Enable self-distillation training between clean and noisy samples", + ) + + parser.add_argument( + "--distill-layers", + type=str, + default="6", + help="Which encoder layer(s) to use for distillation (0-based). " + "Can be a single layer (e.g., '6') or comma-separated list (e.g., '4,6,8'). " + "Clean and noisy outputs from these layers will be compared.", + ) + + parser.add_argument( + "--distill-loss-type", + type=str, + default="mse", + choices=["mse", "cosine"], + help="Type of loss for self-distillation: 'mse' for Mean Squared Error, " + "'cosine' for cosine similarity loss.", + ) + + parser.add_argument( + "--alpha", + type=float, + default=0.7, + help="Weight for self-distillation loss. Total loss = ctc_loss + distill_weight * distill_loss", + ) + + parser.add_argument( + "--distill-aggregation", + type=str, + default="layer_avg", + choices=["layer_avg", "output_avg"], + help="How to aggregate multi-layer distillation losses: " + "'layer_avg' computes loss for each layer and averages them, " + "'output_avg' averages the layer outputs first then computes a single loss.", + ) + + parser.add_argument( + "--knowledge", + type=str, + default="encoder-output", + choices=["encoder-output", "attention-map"], + help="Type of knowledge to use for self-distillation: " + "'encoder-output' uses intermediate encoder layer outputs, " + "'attention-map' uses attention weights from self-attention layers.", + ) + + parser.add_argument( + "--distill-temperature", + type=float, + default=1.0, + help="Temperature for attention map distillation (used with KL divergence). " + "Higher values make attention distributions smoother.", + ) + + # EMA Teacher Model Arguments + parser.add_argument( + "--ema-decay", + type=float, + default=0.999, + help="EMA decay rate for teacher model updates. " + "Higher values (closer to 1.0) make teacher model change more slowly. " + "Typical values: 0.999, 0.9999", + ) + + parser.add_argument( + "--ema-start-step", + type=int, + default=1000, + help="Step number to start EMA teacher model updates. " + "Before this step, teacher model equals student model.", + ) + + parser.add_argument( + "--method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - ctc-decoding: CTC greedy search or beam search. + - nbest-rescoring: Use N-best list for LM rescoring. + - whole-lattice-rescoring: Use whole lattice for LM rescoring. + - attention-decoder: Use attention decoder rescoring. + - rnn-lm: Use RNN LM for rescoring. + """, + ) + + parser.add_argument( + "--enable-validation", + type=str2bool, + default=True, + help="Enable validation during training. 
Set to False to disable validation completely.", + ) + + parser.add_argument( + "--valid-interval", + type=int, + default=3000, + help="Run validation every N batches. Increase this to validate less frequently.", + ) + + parser.add_argument( + "--validation-decoding-method", + type=str, + default="greedy", + choices=["greedy", "beam"], + help="Decoding method for validation: 'greedy' for faster validation, 'beam' for more accurate WER.", + ) + + parser.add_argument( + "--validation-search-beam", + type=float, + default=10.0, + help="Search beam size for validation decoding (only used with beam search).", + ) + + parser.add_argument( + "--validation-output-beam", + type=float, + default=5.0, + help="Output beam size for validation decoding (only used with beam search).", + ) + + parser.add_argument( + "--validation-skip-wer", + type=str2bool, + default=False, + help="Skip WER computation during validation for faster validation (only compute loss).", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. 
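The docstring above notes that command-line options are merged into params after parsing (run() below does params.update(vars(args))). A small sketch of that pattern, assuming icefall's AttributeDict behaves as a plain dict with attribute access, as it is used throughout this file:

    from icefall.utils import AttributeDict

    params = AttributeDict({"valid_interval": 3000, "feature_dim": 80})
    cli_args = {"valid_interval": 1000, "exp_dir": "./conformer_ctc_sd/exp"}  # like vars(args)
    params.update(cli_args)                        # CLI values override the defaults
    print(params.valid_interval, params.exp_dir)   # 1000 ./conformer_ctc_sd/exp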
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # Default value, will be overridden by args + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 256, + "nhead": 4, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for decoding/validation + "search_beam": 20.0, + "output_beam": 8.0, + "min_active_states": 30, + "max_active_states": 10000, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 30000, + "env_info": get_env_info(), + # Self-distillation parameters + "enable_self_distillation": True, + "distill_layer": 6, + "distill_loss_type": "mse", + "distill_weight": 0.1, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ema_teacher: Optional[EMATeacher] = None, +) -> Optional[dict]: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. 
+ """ + if params.start_epoch <= 0: + return + + # First try to find checkpoint in models directory + models_dir = params.exp_dir / "models" + filename = models_dir / f"epoch-{params.start_epoch-1}.pt" + + # If not found in models directory, try the old location for backward compatibility + if not filename.exists(): + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + + if not filename.exists(): + logging.warning(f"Checkpoint not found at {filename}") + return + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + # Try to load EMA teacher checkpoint if it exists + if ema_teacher is not None: + ema_filename = models_dir / f"epoch-{params.start_epoch-1}-ema-teacher.pt" + if not ema_filename.exists(): + # Try old location for backward compatibility + ema_filename = params.exp_dir / f"epoch-{params.start_epoch-1}-ema-teacher.pt" + + if ema_filename.exists(): + try: + ema_state_dict = torch.load(ema_filename, map_location='cpu') + ema_teacher.load_state_dict(ema_state_dict) + logging.info(f"Loaded EMA teacher checkpoint from {ema_filename}") + saved_params['ema_teacher'] = ema_state_dict + except Exception as e: + logging.warning(f"Failed to load EMA teacher checkpoint: {e}") + else: + logging.info("EMA teacher checkpoint not found, will initialize from student model") + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, + suffix: str = "", + wer_value: Optional[float] = None, + step: Optional[int] = None, + ema_teacher: Optional[EMATeacher] = None, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + wer_value: + WER value to include in filename (optional). + step: + Training step to include in filename instead of epoch (optional). 
+ """ + if rank != 0: + return + + # Create models directory if it doesn't exist + models_dir = params.exp_dir / "models" + models_dir.mkdir(exist_ok=True) + + if suffix: + # Use step instead of epoch for validation checkpoints + epoch_or_step = step if step is not None else params.cur_epoch + if wer_value is not None: + filename = models_dir / f"step-{epoch_or_step}-{suffix}-wer{wer_value:.2f}.pt" + else: + filename = models_dir / f"step-{epoch_or_step}-{suffix}.pt" + else: + filename = models_dir / f"epoch-{params.cur_epoch}.pt" + + # Save main checkpoint + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + # Save EMA teacher model separately if it exists + if ema_teacher is not None: + ema_filename = models_dir / f"epoch-{params.cur_epoch}-ema-teacher.pt" + torch.save(ema_teacher.state_dict(), ema_filename) + logging.info(f"EMA teacher checkpoint saved to {ema_filename}") + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = models_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = models_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info(f"Checkpoint saved successfully to {filename}") + # Remove the print statement that might be causing issues + # print("Saving All Done!") + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, + ema_teacher: Optional[EMATeacher] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss with optional self-distillation. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of ConformerCTC. + batch: + A batch of data. Can contain both clean and noisy samples for self-distillation. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. + is_training: + True for training. False for validation. 
+ """ + device = graph_compiler.device + + # Handle clean-noisy batch structure for self-distillation + if 'clean' in batch and 'noisy' in batch and params.enable_self_distillation: + # Self-distillation mode with clean-noisy samples + clean_feature = batch['clean']['inputs'] + noisy_feature = batch['noisy']['inputs'] + clean_supervisions = batch['clean']['supervisions'] + noisy_supervisions = batch['noisy']['supervisions'] + + # Use noisy samples as primary for CTC loss computation + feature = noisy_feature + supervisions = noisy_supervisions + + # Move to device + clean_feature = clean_feature.to(device) + noisy_feature = noisy_feature.to(device) + + use_self_distillation = True + else: + # Normal mode or self-distillation disabled + feature = batch["inputs"] + supervisions = batch["supervisions"] + feature = feature.to(device) + + clean_feature = None + noisy_feature = None + clean_supervisions = None + use_self_distillation = False + + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + + with torch.set_grad_enabled(is_training): + # Forward pass through model (noisy sample) + model_output = model(feature, supervisions) + + # Extract outputs from ConformerCTC + nnet_output = model_output['ctc_output'] # (N, T, C) + distill_hidden = model_output['distill_hidden'] # (N, T, d_model) + encoder_memory = model_output['encoder_out'] # (N, T, d_model) + memory_mask = None # ConformerCTC doesn't return mask, we'll create it from lengths + + # Create memory mask from output lengths + output_lens = model_output['encoder_out_lens'] + max_len = nnet_output.size(1) + memory_mask = torch.arange(max_len, device=device)[None, :] >= output_lens[:, None] + + # Self-distillation computation + distillation_loss = torch.tensor(0.0, device=device) + if use_self_distillation and params.enable_self_distillation: + # Use EMA teacher model if available, otherwise use clean samples + if ema_teacher is not None: + # Get teacher model outputs (teacher is always in eval mode) + teacher_model = ema_teacher.get_teacher_model() + with torch.no_grad(): + # Use noisy samples for teacher model (student samples) + teacher_model_output = teacher_model(feature, supervisions) + elif clean_feature is not None: + # Fallback to clean samples as teacher + teacher_model_output = model(clean_feature, clean_supervisions) + else: + # No teacher available, skip distillation + teacher_model_output = None + + if teacher_model_output is not None: + # Parse distillation layers from comma-separated string + try: + distill_layers = [int(x.strip()) for x in params.distill_layers.split(',')] + except: + distill_layers = [int(params.distill_layers)] + + if params.knowledge == "encoder-output": + # Extract encoder outputs for distillation + if 'distill_outputs' in model_output and 'distill_outputs' in teacher_model_output: + teacher_outputs = teacher_model_output['distill_outputs'] + student_outputs = model_output['distill_outputs'] + + # Import the multi-layer distillation function + from conformer_ctc import compute_multi_layer_distillation_loss + + distillation_loss = compute_multi_layer_distillation_loss( + teacher_knowledge=teacher_outputs, + student_knowledge=student_outputs, + knowledge_lens=output_lens, + layer_indices=distill_layers, + loss_type=params.distill_loss_type, + knowledge_type="encoder-output", + aggregation=params.distill_aggregation, + ) + elif 'distill_hidden' in model_output and 'distill_hidden' in teacher_model_output: + # Fallback to single layer distillation for backward compatibility + from conformer_ctc 
import compute_distillation_loss + distillation_loss = compute_distillation_loss( + teacher_knowledge=teacher_model_output['distill_hidden'], + student_knowledge=model_output['distill_hidden'], + knowledge_lens=output_lens, + loss_type=params.distill_loss_type, + knowledge_type="encoder-output", + ) + + elif params.knowledge == "attention-map": + # Extract attention maps for distillation + if 'attention_maps' in model_output and 'attention_maps' in teacher_model_output: + teacher_attention = teacher_model_output['attention_maps'] + student_attention = model_output['attention_maps'] + + # Import the multi-layer distillation function + from conformer_ctc import compute_multi_layer_distillation_loss + + distillation_loss = compute_multi_layer_distillation_loss( + teacher_knowledge=teacher_attention, + student_knowledge=student_attention, + knowledge_lens=output_lens, + layer_indices=distill_layers, + loss_type="kl", # Always use KL divergence for attention maps + knowledge_type="attention-map", + aggregation=params.distill_aggregation, + temperature=params.distill_temperature, + ) + else: + logging.warning("Attention maps not found in model output. Distillation disabled for this batch.") + distillation_loss = torch.tensor(0.0, device=device) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + # Compute CTC loss + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=max(params.subsampling_factor - 1, 10), + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + # Attention loss computation (if applicable) + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + total_loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + att_loss = torch.tensor([0]) + total_loss = ctc_loss + + # Add self-distillation loss + if use_self_distillation and distillation_loss.item() > 0: + total_loss = (1.0 - params.alpha) * total_loss + params.alpha * distillation_loss + + assert total_loss.requires_grad == is_training + + # Metrics tracking + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + info["att_loss"] = att_loss.detach().cpu().item() + info["distill_loss"] = distillation_loss.detach().cpu().item() + info["loss"] = total_loss.detach().cpu().item() + + # `utt_duration` and 
`utt_pad_proportion` would be normalized by `utterances` + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return total_loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + epoch: int = 1, + quick_validation: bool = True, # Add option for quick validation + rank: int = 0, # Add rank parameter + tb_writer: Optional[SummaryWriter] = None, # Add TensorBoard writer parameter +) -> MetricsTracker: + + + model.eval() + + with torch.no_grad(): + device = next(model.parameters()).device + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + logging.info("Validation loss computation completed") + + # Always compute WER for analysis + logging.info("Starting WER computation...") + + # Use the existing graph_compiler instead of creating a new one + # to ensure device compatibility in DDP training + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + # Read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + vocab_size = len(f.readlines()) + max_token_id = vocab_size - 1 + + # WER calculation with proper device handling + if params.att_rate == 0.0: + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + # For BPE mode, create a simple word table from tokens + if "lang_bpe" in str(params.lang_dir): + # Read tokens and create a simple word table mapping + tokens_file = params.lang_dir / "tokens.txt" + if tokens_file.exists(): + word_table = {} + with open(tokens_file, 'r') as f: + for line in f: + if line.strip(): + parts = line.strip().split() + if len(parts) >= 2: + token, idx = parts[0], parts[1] + word_table[token] = int(idx) + else: + word_table = None + else: + # Phone mode: use lexicon word table + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + + + + # Use validation-specific decoding parameters + if params.validation_decoding_method == "greedy": + logging.info("Starting decode_dataset with GREEDY decoding...") + # Override beam parameters for greedy decoding + original_search_beam = params.search_beam + original_output_beam = params.output_beam + params.search_beam = 1.0 # Greedy = beam size 1 + params.output_beam = 1.0 + else: + logging.info(f"Starting decode_dataset with BEAM search (search_beam={params.validation_search_beam}, 
output_beam={params.validation_output_beam})...") + # Use validation-specific beam parameters + original_search_beam = params.search_beam + original_output_beam = params.output_beam + params.search_beam = params.validation_search_beam + params.output_beam = params.validation_output_beam + + try: + results_dict = decode_dataset( + dl=valid_dl, + params=params, + model=model, + rnn_lm_model=None, # For CTC validation, we don't use RNN LM + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=word_table, + sos_id=sos_id, + eos_id=eos_id, + ) + + except Exception as e: + logging.error(f"decode_dataset failed: {e}") + logging.error("Skipping WER computation for this validation") + # Restore original beam parameters + params.search_beam = original_search_beam + params.output_beam = original_output_beam + + logging.info(f"Validation loss: {loss_value:.4f}") + return tot_loss, None + + # Restore original beam parameters + params.search_beam = original_search_beam + params.output_beam = original_output_beam + + logging.info("Starting save_results...") + + wer_results = save_results(params=params, test_set_name=f"epoch_{epoch}_validation", results_dict=results_dict) + + # Log WER results + if wer_results: + for method, wer_value in wer_results.items(): + logging.info(f"Dataset-level WER ({method}): {wer_value:.2f}% (total errors/total words)") + # Log each WER method to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar(f"validation/wer_{method}", wer_value, params.batch_idx_train) + else: + logging.info("Validation WER: N/A") + + # Log some example predictions vs ground truth for inspection + log_prediction_examples(results_dict, max_examples=3) + + # Log examples to TensorBoard if available + if rank == 0 and tb_writer is not None: + log_validation_examples_to_tensorboard(results_dict, tb_writer, params.batch_idx_train, max_examples=5) + + # Calculate overall WER statistics if we have results + overall_wer = None + if wer_results: + # Find the main WER method (usually the first one or the one with 'wer' in the name) + main_wer_key = None + for key in wer_results.keys(): + if 'wer' in key.lower() or 'word_error_rate' in key.lower(): + main_wer_key = key + break + + if main_wer_key is None and wer_results: + # If no specific WER key found, use the first one + main_wer_key = list(wer_results.keys())[0] + + if main_wer_key: + overall_wer = wer_results[main_wer_key] + logging.info(f"Main dataset-level WER ({main_wer_key}): {overall_wer:.2f}% (total errors/total words)") + # Log the main/total WER to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar("validation/total_wer", overall_wer, params.batch_idx_train) + tb_writer.add_scalar("validation/wer_dataset_level", overall_wer, params.batch_idx_train) + + # Final logging of validation results + logging.info(f"Validation loss: {loss_value:.4f}") + if overall_wer is not None: + logging.info(f"Total validation WER: {overall_wer:.2f}% (dataset-level)") + # Log the final total WER to TensorBoard + if rank == 0 and tb_writer is not None: + tb_writer.add_scalar("validation/loss", loss_value, params.batch_idx_train) + tb_writer.add_scalar("validation/total_wer", overall_wer, params.batch_idx_train) + else: + logging.info("Validation WER: N/A") + + return tot_loss, overall_wer + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + 
tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, + ema_teacher: Optional[EMATeacher] = None, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ema_teacher=ema_teacher, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + # Update EMA teacher model after optimizer step + if ema_teacher is not None and params.batch_idx_train >= params.ema_start_step: + ema_teacher.update(model) + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0 and params.enable_validation: + logging.info(f"Computing validation loss (rank {rank})") + + + # Use quick validation for frequent checks, full validation less frequently + quick_val = (params.batch_idx_train % (params.valid_interval * 5) != 0) + valid_info, validation_wer = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + epoch=params.cur_epoch, + quick_validation=quick_val, + rank=rank, + tb_writer=tb_writer, + ) + + + # Log validation results with WER if available + if validation_wer is not None: + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}, WER: {validation_wer:.2f}%") + else: + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + + # Save checkpoint after validation (only rank 0) + if rank == 0: + logging.info(f"Saving checkpoint after validation at batch {batch_idx}") + try: + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + suffix=f"val-{batch_idx}", + wer_value=validation_wer, + step=batch_idx, + ) + logging.info(f"Checkpoint saved successfully for batch {batch_idx}") + except Exception as e: + logging.error(f"Failed to save checkpoint: {e}") + # Continue training even if checkpoint saving fails + model.train() + + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + # Write WER to 
TensorBoard if validation results file exists and contains WER + wer_summary_file = params.exp_dir / f"wer-summary-epoch_{params.cur_epoch}_validation.txt" + if wer_summary_file.exists(): + try: + with open(wer_summary_file, 'r') as f: + lines = f.readlines() + for line in lines[1:]: # Skip header line + if line.strip(): + parts = line.strip().split('\t') + if len(parts) >= 2: + method_name = parts[0] + wer_value = float(parts[1]) + tb_writer.add_scalar(f"train/valid_WER_{method_name}", wer_value, params.batch_idx_train) + except Exception as e: + logging.warning(f"Could not log WER to TensorBoard: {e}") + + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(f"Warmup steps: {params.warm_step}") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + # Read vocab size from tokens.txt + tokens_file = params.lang_dir / "tokens.txt" + with open(tokens_file, 'r', encoding='utf-8') as f: + num_classes = len(f.readlines()) + max_token_id = num_classes - 1 + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. 
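train_one_epoch above calls ema_teacher.update(model) after every optimizer step once batch_idx_train reaches --ema-start-step. EMATeacher itself lives in ema_teacher.py, which is not part of this hunk, so the sketch below only illustrates the standard exponential-moving-average rule that the --ema-decay argument suggests; treat it as an assumption rather than the module's actual code:

    import torch

    decay = 0.999                               # --ema-decay
    teacher_param = torch.zeros(4)
    student_param = torch.ones(4)
    for _ in range(3):                          # one update per optimizer step
        teacher_param = decay * teacher_param + (1.0 - decay) * student_param
    print(teacher_param)                        # creeps toward the student by ~0.1% per step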
+ graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + + # Determine encoder layers for self-distillation + num_encoder_layers = 12 # Default value, can be made configurable + + # Parse distillation layers + distill_layers = [] + if params.enable_self_distillation and params.distill_layers: + try: + distill_layers = [int(x.strip()) for x in params.distill_layers.split(',')] + logging.info(f"Self-distillation ENABLED with layers: {distill_layers}") + logging.info(f"Knowledge type: {params.knowledge}") + logging.info(f"Loss type: {params.distill_loss_type}") + logging.info(f"Aggregation: {params.distill_aggregation}") + if params.knowledge == "attention-map": + logging.info(f"Temperature: {params.distill_temperature}") + except: + distill_layers = [int(params.distill_layers)] + logging.info(f"Self-distillation ENABLED with single layer: {distill_layers[0]}") + else: + logging.info("Self-distillation DISABLED") + + model = ConformerCTC( + num_features=params.feature_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + num_encoder_layers=num_encoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + # Self-distillation parameters + distill_layers=distill_layers, + knowledge_type=params.knowledge, + ) + + checkpoints = load_checkpoint_if_available( + params=params, + model=model, + ema_teacher=ema_teacher + ) + + model.to(device) + + # Initialize EMA Teacher Model for self-distillation + ema_teacher = None + if params.enable_self_distillation: + logging.info(f"Initializing EMA teacher model with decay={params.ema_decay}, start_step={params.ema_start_step}") + ema_teacher = EMATeacher(model, decay=params.ema_decay, device=device) + + if world_size > 1: + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + # Use only dev_clean for faster validation (dev_other can be added later) + valid_cuts = librispeech.dev_clean_cuts() + # valid_cuts += librispeech.dev_other_cuts() # Comment out for faster validation + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + logging.info(f"Validation set size: {len(valid_cuts)} utterances") + + if params.sanity_check: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + else: pass + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ema_teacher=ema_teacher, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ema_teacher=ema_teacher, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def log_prediction_examples(results_dict, max_examples=5, force_log=False): + """ + Log a few examples of ground truth vs predicted text for validation inspection. + Only logs to terminal every 50 validation samples to reduce clutter. 
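+
+    Illustrative example only (it assumes each entry is a
+    ``(cut_id, ref_words, hyp_words)`` tuple, which is how this function
+    indexes the results)::
+
+        results = {
+            "ctc-decoding": [
+                ("cut-001", ["hello", "world"], ["hello", "word"]),
+            ]
+        }
+        log_prediction_examples(results, max_examples=1, force_log=True)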
+ + Args: + results_dict: Dictionary containing decoding results + max_examples: Maximum number of examples to log + force_log: Force logging regardless of sample counter + """ + global _VALIDATION_SAMPLE_COUNTER + + if not results_dict: + return + + # Get the first method's results (usually there's only one method in validation) + first_method = list(results_dict.keys())[0] + results = results_dict[first_method] + + if not results: + return + + # Update the validation sample counter + _VALIDATION_SAMPLE_COUNTER += len(results) + + # Only log to terminal every 50 samples (or when forced) + should_log_to_terminal = force_log or (_VALIDATION_SAMPLE_COUNTER % 50 == 0) or (_VALIDATION_SAMPLE_COUNTER <= 50) + + if not should_log_to_terminal: + # Still compute and log basic statistics, just not the detailed examples + total_sample_wer = 0 + valid_samples = 0 + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if len(ref_word_list) > 0: + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = len(ref_word_list) + len(hyp_word_list) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_word_list)) * 100 + total_sample_wer += utt_wer + valid_samples += 1 + + # Log summary info only + if valid_samples > 0: + avg_example_wer = total_sample_wer / valid_samples + logging.info(f"Validation batch processed: {valid_samples} samples " + f"(total samples processed: {_VALIDATION_SAMPLE_COUNTER}, detailed examples every 50 samples)") + return + + # Full detailed logging when we hit the 50-sample threshold + logging.info(f"Detailed validation examples (sample #{_VALIDATION_SAMPLE_COUNTER - len(results) + 1}-{_VALIDATION_SAMPLE_COUNTER}):") + + # Select diverse examples: some short, some long, some with errors, some perfect + selected_examples = [] + + # Try to get diverse examples by length and error type + perfect_matches = [] + error_cases = [] + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + if ref_text.split() == hyp_text.split(): + perfect_matches.append(result) + else: + error_cases.append(result) + + # Mix perfect matches and error cases + selected_examples = error_cases[:max_examples-1] + perfect_matches[:1] + if len(selected_examples) < max_examples: + selected_examples.extend(results[:max_examples - len(selected_examples)]) + + selected_examples = selected_examples[:max_examples] + + logging.info("=" * 80) + logging.info(f"VALIDATION EXAMPLES (showing {len(selected_examples)} samples):") + logging.info("=" * 80) + + total_sample_wer = 0 + valid_samples = 0 + + for i, result in enumerate(selected_examples): + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + + # Convert word lists to strings + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + logging.info(f"Example {i+1} (ID: {cut_id}):") + logging.info(f" REF: 
{ref_text}") + logging.info(f" HYP: {hyp_text}") + + # Simple word error analysis + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if ref_word_list == hyp_word_list: + logging.info(f" --> ✅ PERFECT MATCH ({len(ref_word_list)} words, WER: 0.0%)") + total_sample_wer += 0.0 + valid_samples += 1 + else: + # Basic error analysis + ref_len = len(ref_word_list) + hyp_len = len(hyp_word_list) + + # Calculate simple WER for this utterance + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = ref_len + hyp_len - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / ref_len * 100) if ref_len > 0 else 0 + total_sample_wer += utt_wer + valid_samples += 1 + + # Find common words for basic analysis + ref_set = set(ref_word_list) + hyp_set = set(hyp_word_list) + missing_words = ref_set - hyp_set + extra_words = hyp_set - ref_set + + error_info = f"WER: {utt_wer:.1f}%, REF: {ref_len} words, HYP: {hyp_len} words" + if missing_words and len(missing_words) <= 3: + error_info += f", Missing: {list(missing_words)}" + elif missing_words: + error_info += f", Missing: {len(missing_words)} words" + + if extra_words and len(extra_words) <= 3: + error_info += f", Extra: {list(extra_words)}" + elif extra_words: + error_info += f", Extra: {len(extra_words)} words" + + logging.info(f" --> ❌ ERRORS ({error_info})") + logging.info("") + + # Log average WER for the examples + if valid_samples > 0: + avg_example_wer = total_sample_wer / valid_samples + logging.info(f"Average WER for these {valid_samples} examples: {avg_example_wer:.2f}%") + + logging.info("=" * 80) + + +def log_validation_examples_to_tensorboard(results_dict, tb_writer, step, max_examples=5): + """ + Log validation examples to TensorBoard as text. 
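+
+    Illustrative example only (``results`` uses the same
+    ``(cut_id, ref_words, hyp_words)`` layout as in
+    :func:`log_prediction_examples`; the log directory is a placeholder)::
+
+        from torch.utils.tensorboard import SummaryWriter
+
+        writer = SummaryWriter("exp/tensorboard")
+        log_validation_examples_to_tensorboard(results, writer, step=100)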
+ + Args: + results_dict: Dictionary containing decoding results + tb_writer: TensorBoard writer + step: Current training step + max_examples: Maximum number of examples to log + """ + if not results_dict or tb_writer is None: + return + + # Get the first method's results + first_method = list(results_dict.keys())[0] + results = results_dict[first_method] + + if not results: + return + + # Select diverse examples + selected_examples = [] + perfect_matches = [] + error_cases = [] + + for result in results: + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + if ref_text.split() == hyp_text.split(): + perfect_matches.append(result) + else: + error_cases.append(result) + + # Mix error cases and perfect matches + selected_examples = error_cases[:max_examples-1] + perfect_matches[:1] + if len(selected_examples) < max_examples: + selected_examples.extend(results[:max_examples - len(selected_examples)]) + + selected_examples = selected_examples[:max_examples] + + # Create text to log to TensorBoard + tb_text = "## Validation Examples\n\n" + + total_wer = 0 + valid_count = 0 + + for i, result in enumerate(selected_examples): + if len(result) >= 3: + cut_id, ref_words, hyp_words = result[0], result[1], result[2] + + ref_text = " ".join(ref_words) if isinstance(ref_words, list) else str(ref_words) + hyp_text = " ".join(hyp_words) if isinstance(hyp_words, list) else str(hyp_words) + + tb_text += f"**Example {i+1} (ID: {cut_id})**\n\n" + tb_text += f"- **REF:** {ref_text}\n" + tb_text += f"- **HYP:** {hyp_text}\n" + + # Calculate simple WER for this utterance + ref_word_list = ref_text.split() + hyp_word_list = hyp_text.split() + + if ref_word_list == hyp_word_list: + tb_text += f"- **Result:** ✅ PERFECT MATCH ({len(ref_word_list)} words, WER: 0.0%)\n\n" + total_wer += 0.0 + valid_count += 1 + else: + import difflib + matcher = difflib.SequenceMatcher(None, ref_word_list, hyp_word_list) + word_errors = len(ref_word_list) + len(hyp_word_list) - 2 * sum(triple.size for triple in matcher.get_matching_blocks()) + utt_wer = (word_errors / len(ref_word_list) * 100) if len(ref_word_list) > 0 else 0 + tb_text += f"- **Result:** ❌ WER: {utt_wer:.1f}% (REF: {len(ref_word_list)} words, HYP: {len(hyp_word_list)} words)\n\n" + total_wer += utt_wer + valid_count += 1 + + # Add summary statistics + if valid_count > 0: + avg_wer = total_wer / valid_count + tb_text += f"**Summary:** Average WER for {valid_count} examples: {avg_wer:.2f}%\n\n" + + # Log to TensorBoard + tb_writer.add_text("Validation/Examples", tb_text, step) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.bpe_dir = Path(args.bpe_dir) + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc_sd/transformer.py old mode 100644 new mode 100755 similarity index 100% rename from egs/gigaspeech/ASR/conformer_ctc/transformer.py rename to egs/librispeech/ASR/conformer_ctc_sd/transformer.py diff --git 
a/egs/librispeech/ASR/conformer_mmi/__init__.py b/egs/librispeech/ASR/conformer_mmi/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/test_emformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/create_chime4_test.py b/egs/librispeech/ASR/create_chime4_test.py new file mode 100644 index 000000000..79a64151a --- /dev/null +++ b/egs/librispeech/ASR/create_chime4_test.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +""" +Simple CHiME-4 test dataloader creation script. +Creates a small subset for quick testing. 
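+
+Usage (illustrative)::
+
+    python create_chime4_test.py --max-files 5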
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List + +from lhotse import CutSet, Recording, RecordingSet, SupervisionSegment, SupervisionSet +from lhotse.dataset import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse import Fbank, FbankConfig +from torch.utils.data import DataLoader + + +def create_simple_chime4_test_loader( + audio_root: Path = Path("/home/nas/DB/CHiME4/data/audio/16kHz/isolated"), + transcript_root: Path = Path("/home/nas/DB/CHiME4/data/transcriptions"), + max_files: int = 10 +) -> DataLoader: + """Create a simple test dataloader with limited CHiME-4 files.""" + + logging.info(f"Creating simple CHiME-4 test loader with max {max_files} files") + + # Focus on dt05_bth (clean booth) for simplicity + audio_dir = audio_root / "dt05_bth" + transcript_dir = transcript_root / "dt05_bth" + + if not audio_dir.exists(): + raise FileNotFoundError(f"Audio directory not found: {audio_dir}") + if not transcript_dir.exists(): + raise FileNotFoundError(f"Transcript directory not found: {transcript_dir}") + + # Get limited audio files + wav_files = sorted(list(audio_dir.glob("*.wav")))[:max_files] + logging.info(f"Found {len(wav_files)} audio files to process") + + # Parse transcriptions from individual .trn files + transcriptions = {} + for trn_file in transcript_dir.glob("*.trn"): + try: + with open(trn_file, 'r', encoding='utf-8') as f: + line = f.read().strip() + if line: + parts = line.split(' ', 1) + if len(parts) == 2: + utterance_id = parts[0] + text = parts[1] + transcriptions[utterance_id] = text + logging.debug(f"Loaded transcription: {utterance_id}") + except Exception as e: + logging.warning(f"Failed to read {trn_file}: {e}") + + logging.info(f"Loaded {len(transcriptions)} transcriptions") + + # Create recordings and supervisions + recordings = [] + supervisions = [] + + for wav_file in wav_files: + # Extract utterance ID from filename (remove .CH0, etc.) 
+ utterance_id = wav_file.stem + if '.CH' in utterance_id: + utterance_id = utterance_id.split('.CH')[0] + + # Skip if no transcription + if utterance_id not in transcriptions: + logging.warning(f"No transcription for {utterance_id}") + continue + + try: + # Create recording + recording = Recording.from_file(wav_file) + recording = Recording( + id=utterance_id, + sources=recording.sources, + sampling_rate=recording.sampling_rate, + num_samples=recording.num_samples, + duration=recording.duration, + channel_ids=recording.channel_ids, + transforms=recording.transforms + ) + recordings.append(recording) + + # Create supervision + text = transcriptions[utterance_id] + supervision = SupervisionSegment( + id=utterance_id, + recording_id=utterance_id, + start=0.0, + duration=recording.duration, + channel=0, + text=text, + language="English" + ) + supervisions.append(supervision) + + logging.info(f"Processed: {utterance_id} - '{text[:50]}...'") + + except Exception as e: + logging.warning(f"Failed to process {wav_file}: {e}") + continue + + if not recordings: + raise ValueError("No valid recordings found!") + + # Create manifests + recording_set = RecordingSet.from_recordings(recordings) + supervision_set = SupervisionSet.from_segments(supervisions) + cuts = CutSet.from_manifests(recordings=recording_set, supervisions=supervision_set) + + logging.info(f"Created {len(cuts)} cuts for CHiME-4 test") + + # Create dataset and dataloader + dataset = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=True + ) + + # Simple sampler - no bucketing for test + from lhotse.dataset import SimpleCutSampler + sampler = SimpleCutSampler(cuts, max_duration=30.0, shuffle=False) + + dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=1 + ) + + logging.info(f"Created CHiME-4 test dataloader with {len(cuts)} utterances") + return dataloader, cuts + + +def main(): + parser = argparse.ArgumentParser(description="Create simple CHiME-4 test dataloader") + parser.add_argument("--max-files", type=int, default=10, help="Max files to process") + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + try: + dataloader, cuts = create_simple_chime4_test_loader(max_files=args.max_files) + + # Test the dataloader + logging.info("Testing dataloader...") + for i, batch in enumerate(dataloader): + if i >= 2: # Just test first 2 batches + break + logging.info(f"Batch {i}: {batch['supervisions']['text']}") + + logging.info("CHiME-4 test dataloader creation successful!") + + except Exception as e: + logging.error(f"Failed to create CHiME-4 test dataloader: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/egs/librispeech/ASR/debug_import.py b/egs/librispeech/ASR/debug_import.py new file mode 100755 index 000000000..78a770142 --- /dev/null +++ b/egs/librispeech/ASR/debug_import.py @@ -0,0 +1,89 @@ +# #!/usr/bin/env python3 +# import sys +# import os + +# print("1. 기본 모듈 임포트") +# import torch +# import k2 + +# print("2. 현재 경로에 추가") +# sys.path.insert(0, os.getcwd()) + +# print("3. train.py에서 필요한 함수들 임포트") +# try: +# from conformer_ctc.train import get_parser +# print("get_parser 성공") +# except Exception as e: +# print(f"get_parser 실패: {e}") +# sys.exit(1) + +# print("4. 
파서 생성 및 인수 파싱") +# try: +# parser = get_parser() +# args = parser.parse_args(['--full-libri', 'True', '--num-epochs', '1', '--world-size', '1', '--att-rate', '0.0', '--device', 'cpu']) +# print("인수 파싱 성공") +# print(f"args: {args}") +# except Exception as e: +# print(f"인수 파싱 실패: {e}") +# sys.exit(1) + +# print("5. 데이터 모듈 임포트") +# try: +# from conformer_ctc.train import LibriSpeechAsrDataModule +# print("LibriSpeechAsrDataModule 임포트 성공") +# except Exception as e: +# print(f"LibriSpeechAsrDataModule 임포트 실패: {e}") +# sys.exit(1) + +# print("6. main 함수 임포트") +# try: +# from conformer_ctc.train import main +# print("main 함수 임포트 성공") +# except Exception as e: +# print(f"main 함수 임포트 실패: {e}") +# sys.exit(1) + +# print("7. run 함수의 처음 부분만 실행") +# try: +# from conformer_ctc.train import run +# print("run 함수 임포트 성공") +# # 실제로 실행하지는 않고 import만 확인 +# except Exception as e: +# print(f"run 함수 임포트 실패: {e}") +# sys.exit(1) + +# print("모든 단계 통과 - 디버깅 완료") + + +import sys + +# 각 임포트를 개별적으로 시도 +imports = [ + "import torch", + "import k2", + "from typing import Optional, Tuple", + "from pathlib import Path", + "from conformer_ctc.conformer import Conformer", + "import sentencepiece as spm", + "from icefall.utils import AttributeDict", + "from icefall.checkpoint import load_checkpoint", + "from icefall.dist import cleanup_dist, setup_dist", + "from asr_datamodule import LibriSpeechAsrDataModule", + "from icefall.env import get_env_info", + "from icefall.lexicon import Lexicon", + "from icefall.utils import AttributeDict", + "from icefall.utils import load_averaged_model", + "from icefall.utils import MetricsTracker", + "from icefall.utils import encode_supervisions", + "from icefall.utils import setup_logger", + "from icefall.utils import str2bool", +] + +for i, imp in enumerate(imports): + print(f"시도 {i+1}: {imp}") + try: + exec(imp) + print(f"✅ 성공: {imp}") + except Exception as e: + print(f"❌ 실패: {imp} - {e}") + break \ No newline at end of file diff --git a/egs/librispeech/ASR/decode.sh b/egs/librispeech/ASR/decode.sh new file mode 100644 index 000000000..cc2759791 --- /dev/null +++ b/egs/librispeech/ASR/decode.sh @@ -0,0 +1,11 @@ +if [ -z "${PYTHONPATH:-}" ]; then + export PYTHONPATH="/tmp/icefall" +else + export PYTHONPATH="${PYTHONPATH}:/tmp/icefall" +fi + +CUDA_VISIBLE_DEVICES=3 python ./conformer_ctc/decode.py \ + --method ctc-decoding \ + --max-duration 10 \ + --epoch 77 \ + --avg 10 diff --git a/egs/librispeech/ASR/evaluate_chime4.py b/egs/librispeech/ASR/evaluate_chime4.py new file mode 100644 index 000000000..8aae1a6e5 --- /dev/null +++ b/egs/librispeech/ASR/evaluate_chime4.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Evaluate trained conformer_ctc model on CHiME-4 dataset. 
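+
+Usage (illustrative; mirrors evaluate_chime4.sh, paths are placeholders)::
+
+    python evaluate_chime4.py \
+        --checkpoint conformer_ctc/exp/pretrained.pt \
+        --manifest-dir data/fbank \
+        --max-duration 200.0 \
+        --device cuda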
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List + +import torch +from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule +from conformer_ctc.conformer import Conformer + + +def setup_logging(args): + """Setup logging configuration.""" + log_level = getattr(logging, args.log_level.upper()) + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + ) + + +def load_model(checkpoint_path: Path, device: torch.device): + """Load trained conformer model from checkpoint.""" + logging.info(f"Loading model from {checkpoint_path}") + + # Load checkpoint + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Extract model parameters from checkpoint + params = checkpoint.get('params', {}) + + # Create model with parameters from checkpoint + model = Conformer( + num_features=params.get('num_features', 80), + nhead=params.get('nhead', 8), + d_model=params.get('d_model', 512), + num_classes=params.get('num_classes', 5000), # Adjust based on your vocab + subsampling_factor=params.get('subsampling_factor', 4), + num_decoder_layers=params.get('num_decoder_layers', 0), + vgg_frontend=params.get('vgg_frontend', False), + num_encoder_layers=params.get('num_encoder_layers', 12), + att_rate=params.get('att_rate', 0.0), + # Add other parameters as needed + ) + + # Load state dict + if 'model' in checkpoint: + model.load_state_dict(checkpoint['model']) + elif 'state_dict' in checkpoint: + model.load_state_dict(checkpoint['state_dict']) + else: + model.load_state_dict(checkpoint) + + model = model.to(device) + model.eval() + + logging.info(f"Model loaded successfully with {sum(p.numel() for p in model.parameters())} parameters") + return model + + +def evaluate_chime4(model, datamodule, device: torch.device): + """Evaluate model on CHiME-4 test sets.""" + from conformer_ctc.decode import greedy_search + + # Get CHiME-4 test dataloaders + test_loaders = datamodule.chime4_test_dataloaders() + + if not test_loaders: + logging.error("No CHiME-4 test dataloaders found!") + return {} + + results = {} + + for test_set_name, test_loader in test_loaders.items(): + logging.info(f"Evaluating on CHiME-4 {test_set_name}") + + total_num_tokens = 0 + total_num_errors = 0 + + with torch.no_grad(): + for batch_idx, batch in enumerate(test_loader): + if batch_idx % 10 == 0: + logging.info(f"Processing batch {batch_idx} of {test_set_name}") + + feature = batch["inputs"].to(device) + # Convert supervisions to expected format + supervisions = batch["supervisions"] + + # Forward pass + encoder_out, encoder_out_lens = model.encode(feature, supervisions) + + # Greedy decoding + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + # Calculate WER (simplified - you may want to use proper WER calculation) + for i, hyp in enumerate(hyps): + ref_tokens = supervisions["text"][i].split() + hyp_tokens = hyp.split() + + total_num_tokens += len(ref_tokens) + # Simple edit distance calculation (you may want to use proper edit distance) + errors = abs(len(ref_tokens) - len(hyp_tokens)) + total_num_errors += errors + + if batch_idx == 0 and i == 0: # Print first example + logging.info(f"Reference: {supervisions['text'][i]}") + logging.info(f"Hypothesis: {hyp}") + + # Calculate WER + wer = total_num_errors / total_num_tokens if total_num_tokens > 0 else 1.0 + results[test_set_name] = {"WER": wer, "total_tokens": total_num_tokens} + + logging.info(f"{test_set_name} 
WER: {wer:.4f} ({total_num_errors}/{total_num_tokens})") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate conformer CTC on CHiME-4") + parser.add_argument( + "--checkpoint", type=Path, required=True, help="Path to model checkpoint" + ) + parser.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with manifests", + ) + parser.add_argument( + "--max-duration", type=float, default=200.0, help="Max duration for batching" + ) + parser.add_argument( + "--log-level", type=str, default="INFO", help="Logging level" + ) + parser.add_argument( + "--device", type=str, default="cuda", help="Device to use (cuda/cpu)" + ) + + args = parser.parse_args() + + setup_logging(args) + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + logging.info(f"Using device: {device}") + + # Load model + model = load_model(args.checkpoint, device) + + # Create data module + datamodule = LibriSpeechAsrDataModule(args) + + # Evaluate on CHiME-4 + results = evaluate_chime4(model, datamodule, device) + + # Print summary + logging.info("=" * 50) + logging.info("CHiME-4 Evaluation Results:") + for test_set, result in results.items(): + logging.info(f"{test_set}: WER = {result['WER']:.4f}") + logging.info("=" * 50) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/evaluate_chime4.sh b/egs/librispeech/ASR/evaluate_chime4.sh new file mode 100755 index 000000000..5a39d26dd --- /dev/null +++ b/egs/librispeech/ASR/evaluate_chime4.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# CHiME-4 evaluation script for conformer_ctc + +set -euo pipefail + +# Configuration +CHECKPOINT_PATH="conformer_ctc/exp/pretrained.pt" # Update with your actual checkpoint +LOG_LEVEL="INFO" +DEVICE="cuda" + +echo "CHiME-4 Evaluation for Conformer CTC" +echo "=====================================" + +# Check if checkpoint exists +if [ ! -f "$CHECKPOINT_PATH" ]; then + echo "Error: Checkpoint not found at $CHECKPOINT_PATH" + echo "Please train a model first or specify correct checkpoint path" + exit 1 +fi + +# Check if CHiME-4 data exists +if [ ! -d "/home/nas/DB/CHiME4/data/audio/16kHz/isolated" ]; then + echo "Error: CHiME-4 data not found at /home/nas/DB/CHiME4/data/audio/16kHz/isolated" + echo "Please check CHiME-4 data path" + exit 1 +fi + +echo "Starting CHiME-4 evaluation..." +echo "Checkpoint: $CHECKPOINT_PATH" +echo "Device: $DEVICE" +echo "" + +# Run evaluation +python evaluate_chime4.py \ + --checkpoint "$CHECKPOINT_PATH" \ + --manifest-dir data/fbank \ + --max-duration 200.0 \ + --log-level "$LOG_LEVEL" \ + --device "$DEVICE" + +echo "" +echo "CHiME-4 evaluation completed!" diff --git a/egs/librispeech/ASR/hf_upload_guide.sh b/egs/librispeech/ASR/hf_upload_guide.sh new file mode 100755 index 000000000..ad5d28048 --- /dev/null +++ b/egs/librispeech/ASR/hf_upload_guide.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +echo "허깅페이스 모델 업로드 가이드" +echo "================================" +echo "" +echo "1. 허깅페이스 토큰 생성 및 로그인:" +echo " huggingface-cli login" +echo "" +echo "2. 모델 업로드 실행:" +echo " python upload_to_huggingface.py" +echo "" +echo "3. 필요한 정보:" +echo " - Repository name (예: jenny/icefall-conformer-ctc-librispeech)" +echo " - Hugging Face token (Write 권한 필요)" +echo " - Private repository 여부" +echo "" +echo "4. 
업로드될 파일들:" +echo " - best-valid-loss.pt (모델 체크포인트)" +echo " - README.md (모델 카드)" +echo " - config.json (모델 설정)" +echo " - inference_example.py (사용 예제)" +echo " - requirements.txt (의존성 패키지)" +echo "" +echo "업로드를 시작하려면:" +echo "python upload_to_huggingface.py" diff --git a/egs/librispeech/ASR/local/__init__.py b/egs/librispeech/ASR/local/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/local/compute_fbank_rir.py b/egs/librispeech/ASR/local/compute_fbank_rir.py new file mode 100644 index 000000000..24be6376f --- /dev/null +++ b/egs/librispeech/ASR/local/compute_fbank_rir.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +""" +This file computes fbank features of the RIR dataset. +It looks for RIR recordings and generates fbank features. + +The generated fbank features are saved in data/fbank. +""" +import argparse +import logging +import os +from pathlib import Path + +import torch +import soundfile as sf +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + MonoCut, + RecordingSet, + Recording, +) +from lhotse.audio import AudioSource + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_rir( + rir_scp: str = "data/manifests/rir.scp", + num_mel_bins: int = 80, + output_dir: str = "data/fbank", + max_files: int = None +): + """ + Compute fbank features for RIR files. 
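+
+    Illustrative example (the arguments shown are the defaults, except
+    ``max_files``)::
+
+        compute_fbank_rir(
+            rir_scp="data/manifests/rir.scp",
+            num_mel_bins=80,
+            output_dir="data/fbank",
+            max_files=100,
+        )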
+ + Args: + rir_scp: Path to rir.scp file + num_mel_bins: Number of mel filter banks + output_dir: Output directory for features + max_files: Maximum number of RIR files to process (for testing) + """ + output_dir = Path(output_dir) + num_jobs = min(15, os.cpu_count()) + + rir_cuts_path = output_dir / "rir_cuts.jsonl.gz" + + if rir_cuts_path.is_file(): + logging.info(f"{rir_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for RIR") + + # Create RIR recordings from scp file + recordings = [] + with open(rir_scp, 'r') as f: + for idx, line in enumerate(f): + if max_files and idx >= max_files: + break + + rir_path = Path(line.strip()) + if not rir_path.exists(): + logging.warning(f"RIR file not found: {rir_path}") + continue + + rir_id = f"rir_{idx:06d}" + + try: + # Get audio info using soundfile + with sf.SoundFile(rir_path) as audio_file: + sampling_rate = audio_file.samplerate + num_samples = len(audio_file) + duration = num_samples / sampling_rate + + # Create recording with proper metadata + recording = Recording( + id=rir_id, + sources=[ + AudioSource( + type="file", + channels=[0], + source=str(rir_path.resolve()), + ) + ], + sampling_rate=int(sampling_rate), + num_samples=int(num_samples), + duration=float(duration), + ) + recordings.append(recording) + + except Exception as e: + logging.warning(f"Failed to process {rir_path}: {e}") + continue + + if (idx + 1) % 1000 == 0: + logging.info(f"Processed {idx + 1} RIR files...") + + logging.info(f"Found {len(recordings)} RIR files") + + # Create recording set + rir_recordings = RecordingSet.from_recordings(recordings) + + # Feature extractor + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: + # Create cuts and compute features + rir_cuts = ( + CutSet.from_manifests(recordings=rir_recordings) + .compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/rir_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + ) + rir_cuts.to_file(rir_cuts_path) + + logging.info(f"Saved RIR cuts with features to {rir_cuts_path}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--rir-scp", + type=str, + default="data/manifests/rir.scp", + help="Path to rir.scp file. Default: data/manifests/rir.scp", + ) + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="The number of mel bins for Fbank. Default: 80", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank", + ) + parser.add_argument( + "--max-files", + type=int, + default=None, + help="Maximum number of RIR files to process (for testing). 
Default: None (process all)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + compute_fbank_rir( + rir_scp=args.rir_scp, + num_mel_bins=args.num_mel_bins, + output_dir=args.output_dir, + max_files=args.max_files, + ) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/local/prepare_rir.py b/egs/librispeech/ASR/local/prepare_rir.py new file mode 100644 index 000000000..1112275d1 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_rir.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +""" +Prepare RIR (Room Impulse Response) data for lhotse. +This script converts rir.scp file to lhotse manifest format. +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +from lhotse import CutSet, Recording, SupervisionSegment +from lhotse.audio import AudioSource +from lhotse.utils import Pathlike + +def get_args(): + parser = argparse.ArgumentParser( + description="Prepare RIR data for lhotse", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--rir-scp", + type=Path, + required=True, + help="Path to rir.scp file containing RIR file paths", + ) + + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for RIR manifests", + ) + + return parser.parse_args() + + +def prepare_rir_manifest( + rir_scp: Pathlike, + output_dir: Pathlike, +) -> None: + """ + Prepare RIR manifest from rir.scp file. + + Args: + rir_scp: Path to rir.scp file + output_dir: Output directory for manifests + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + recordings = [] + + # Read rir.scp file + with open(rir_scp, 'r') as f: + for line_idx, line in enumerate(f): + line = line.strip() + if not line or line.startswith('#'): + continue + + # Parse line: either "path" or "id path" + parts = line.split() + if len(parts) == 1: + rir_path = parts[0] + rir_id = f"rir_{line_idx:06d}" + elif len(parts) == 2: + rir_id, rir_path = parts + else: + logging.warning(f"Invalid line in rir.scp: {line}") + continue + + # Check if file exists + rir_path = Path(rir_path) + if not rir_path.exists(): + logging.warning(f"RIR file not found: {rir_path}") + continue + + # Create recording + recording = Recording( + id=rir_id, + sources=[ + AudioSource( + type="file", + channels=[0], + source=str(rir_path.resolve()), + ) + ], + sampling_rate=16000, # Assume 16kHz, will be auto-detected by lhotse + num_samples=None, # Will be auto-detected + duration=None, # Will be auto-detected + ) + + recordings.append(recording) + + logging.info(f"Found {len(recordings)} RIR files") + + # Create recording set and save + from lhotse import RecordingSet + recording_set = RecordingSet.from_recordings(recordings) + + # Validate recordings (this will auto-detect duration, sampling_rate, etc.) 
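+    # NOTE: num_samples/duration were left as None when the recordings were
+    # created above; if lhotse does not fill them in automatically, prefer
+    # local/prepare_rir_minimal.py or local/simple_rir.py, which read the
+    # real values with soundfile.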
+ logging.info("Validating RIR recordings...") + recording_set = recording_set.with_path_prefix("") # Ensure absolute paths + + # Save recording manifest + output_path = output_dir / "rir_recordings.jsonl.gz" + recording_set.to_file(output_path) + logging.info(f"Saved RIR recording manifest to {output_path}") + + # Create a simple cuts manifest for RIR (whole files) + logging.info("Creating RIR cuts manifest...") + rir_cuts = CutSet.from_manifests(recordings=recording_set) + cuts_output_path = output_dir / "rir_cuts.jsonl.gz" + rir_cuts.to_file(cuts_output_path) + logging.info(f"Saved RIR cuts manifest to {cuts_output_path}") + + return recording_set, rir_cuts + + +def main(): + args = get_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + logging.info("Preparing RIR data...") + prepare_rir_manifest( + rir_scp=args.rir_scp, + output_dir=args.output_dir, + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/prepare_rir_fixed.py b/egs/librispeech/ASR/local/prepare_rir_fixed.py new file mode 100644 index 000000000..ee87dc41c --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_rir_fixed.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Fixed version of prepare RIR data for lhotse. +This script converts rir.scp file to lhotse manifest format. +""" + +import argparse +import logging +from pathlib import Path +from typing import List +import json +import gzip + +from lhotse import CutSet, Recording, RecordingSet +from lhotse.audio import AudioSource +from lhotse.utils import Pathlike + +def get_args(): + parser = argparse.ArgumentParser( + description="Prepare RIR data for lhotse", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--rir-scp", + type=Path, + required=True, + help="Path to rir.scp file containing RIR file paths", + ) + + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for RIR manifests", + ) + + return parser.parse_args() + + +def prepare_rir_manifest( + rir_scp: Pathlike, + output_dir: Pathlike, +) -> None: + """ + Prepare RIR manifest from rir.scp file. 
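+
+    The ``rir.scp`` file is expected to contain one entry per line, either
+    a bare path or an id followed by a path. Illustrative example (paths
+    are placeholders)::
+
+        /data/rir/room1_0001.wav
+        rir_room2_0001 /data/rir/room2_0001.wav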
+ + Args: + rir_scp: Path to rir.scp file + output_dir: Output directory for manifests + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + recordings = [] + + # Read rir.scp file + with open(rir_scp, 'r') as f: + for line_idx, line in enumerate(f): + line = line.strip() + if not line or line.startswith('#'): + continue + + # Parse line: either "path" or "id path" + parts = line.split() + if len(parts) == 1: + rir_path = parts[0] + rir_id = f"rir_{line_idx:06d}" + elif len(parts) == 2: + rir_id, rir_path = parts + else: + logging.warning(f"Invalid line in rir.scp: {line}") + continue + + # Check if file exists + rir_path = Path(rir_path) + if not rir_path.exists(): + logging.warning(f"RIR file not found: {rir_path}") + continue + + # Create recording + recording = Recording( + id=rir_id, + sources=[ + AudioSource( + type="file", + channels=[0], + source=str(rir_path.resolve()), + ) + ], + sampling_rate=16000, # Assume 16kHz, will be auto-detected by lhotse + num_samples=None, # Will be auto-detected + duration=None, # Will be auto-detected + ) + + recordings.append(recording) + + logging.info(f"Found {len(recordings)} RIR files") + + # Create recording set and save + recording_set = RecordingSet.from_recordings(recordings) + + # Validate recordings (this will auto-detect duration, sampling_rate, etc.) + logging.info("Validating RIR recordings...") + + # Save recording manifest + output_path = output_dir / "rir_recordings.jsonl.gz" + recording_set.to_file(output_path) + logging.info(f"Saved RIR recording manifest to {output_path}") + + # Create cuts manually to ensure correct format + logging.info("Creating RIR cuts manifest...") + cuts_data = [] + + for recording in recording_set: + cut_data = { + "id": f"{recording.id}-0", + "start": 0, + "duration": recording.duration, + "channel": 0, + "recording": recording.to_dict() + } + cuts_data.append(cut_data) + + # Save cuts manually + cuts_output_path = output_dir / "rir_cuts.jsonl.gz" + with gzip.open(cuts_output_path, 'wt') as f: + for cut in cuts_data: + f.write(json.dumps(cut) + '\n') + + logging.info(f"Saved RIR cuts manifest to {cuts_output_path}") + + # Verify the cuts can be loaded + try: + from lhotse import load_manifest + cuts_test = load_manifest(cuts_output_path) + logging.info(f"Successfully verified: loaded {len(cuts_test)} cuts") + except Exception as e: + logging.error(f"Failed to verify cuts: {e}") + + return recording_set + + +def main(): + args = get_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + logging.info("Preparing RIR data...") + prepare_rir_manifest( + rir_scp=args.rir_scp, + output_dir=args.output_dir, + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/prepare_rir_minimal.py b/egs/librispeech/ASR/local/prepare_rir_minimal.py new file mode 100644 index 000000000..6d17feb3c --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_rir_minimal.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Simple approach: create minimal RIR cuts without extra validation. 
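+
+Usage (illustrative; paths are placeholders)::
+
+    python local/prepare_rir_minimal.py \
+        --rir-scp data/manifests/rir.scp \
+        --output-dir data/rir \
+        --max-files 1000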
+""" + +import argparse +import logging +from pathlib import Path +import json +import gzip +import soundfile as sf + +def get_args(): + parser = argparse.ArgumentParser( + description="Create minimal RIR cuts manifest", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--rir-scp", + type=Path, + required=True, + help="Path to rir.scp file containing RIR file paths", + ) + + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for RIR manifests", + ) + + parser.add_argument( + "--max-files", + type=int, + default=1000, + help="Maximum number of RIR files to process (for testing)", + ) + + return parser.parse_args() + + +def create_minimal_rir_cuts( + rir_scp: Path, + output_dir: Path, + max_files: int = 1000 +) -> None: + """ + Create a minimal RIR cuts manifest. + """ + output_dir.mkdir(parents=True, exist_ok=True) + + cuts_data = [] + recordings_data = [] + + # Read rir.scp file (limited for testing) + with open(rir_scp, 'r') as f: + for line_idx, line in enumerate(f): + if line_idx >= max_files: + break + + line = line.strip() + if not line or line.startswith('#'): + continue + + rir_path = Path(line.strip()) + if not rir_path.exists(): + logging.warning(f"RIR file not found: {rir_path}") + continue + + rir_id = f"rir_{line_idx:06d}" + + try: + # Get audio info + with sf.SoundFile(rir_path) as audio: + sampling_rate = audio.samplerate + num_samples = len(audio) + duration = num_samples / sampling_rate + + # Create recording entry + recording = { + "id": rir_id, + "sources": [{ + "type": "file", + "channels": [0], + "source": str(rir_path.resolve()) + }], + "sampling_rate": int(sampling_rate), + "num_samples": int(num_samples), + "duration": float(duration), + "channel_ids": [0] + } + recordings_data.append(recording) + + # Create cut entry + cut = { + "id": f"{rir_id}-0", + "start": 0.0, + "duration": float(duration), + "channel": 0, + "recording_id": rir_id + } + cuts_data.append(cut) + + if (line_idx + 1) % 100 == 0: + logging.info(f"Processed {line_idx + 1} RIR files...") + + except Exception as e: + logging.warning(f"Failed to process {rir_path}: {e}") + continue + + logging.info(f"Successfully processed {len(cuts_data)} RIR files") + + # Save recordings manifest + recordings_path = output_dir / "rir_recordings.jsonl.gz" + with gzip.open(recordings_path, 'wt') as f: + for recording in recordings_data: + f.write(json.dumps(recording) + '\n') + logging.info(f"Saved recordings to {recordings_path}") + + # Save cuts manifest + cuts_path = output_dir / "rir_cuts.jsonl.gz" + with gzip.open(cuts_path, 'wt') as f: + for cut in cuts_data: + f.write(json.dumps(cut) + '\n') + logging.info(f"Saved cuts to {cuts_path}") + + # Test loading + try: + from lhotse import load_manifest + cuts_test = load_manifest(cuts_path) + recordings_test = load_manifest(recordings_path) + logging.info(f"✓ Successfully verified: {len(cuts_test)} cuts, {len(recordings_test)} recordings") + except Exception as e: + logging.error(f"✗ Verification failed: {e}") + + +def main(): + args = get_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + logging.info(f"Creating minimal RIR manifest (max {args.max_files} files)...") + create_minimal_rir_cuts( + rir_scp=args.rir_scp, + output_dir=args.output_dir, + max_files=args.max_files + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/prepare_rir_standard.py 
b/egs/librispeech/ASR/local/prepare_rir_standard.py new file mode 100644 index 000000000..bfc34599d --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_rir_standard.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +Create RIR cuts using lhotse's standard approach. +This should create a properly formatted cuts manifest. +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, Recording, RecordingSet +from lhotse.audio import AudioSource +from lhotse.utils import Pathlike + +def get_args(): + parser = argparse.ArgumentParser( + description="Prepare RIR data for lhotse using standard approach", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--rir-scp", + type=Path, + required=True, + help="Path to rir.scp file containing RIR file paths", + ) + + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for RIR manifests", + ) + + return parser.parse_args() + + +def prepare_rir_manifest( + rir_scp: Pathlike, + output_dir: Pathlike, +) -> None: + """ + Prepare RIR manifest using lhotse's standard approach. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + recordings = [] + + # Read rir.scp file + with open(rir_scp, 'r') as f: + for line_idx, line in enumerate(f): + line = line.strip() + if not line or line.startswith('#'): + continue + + # Parse line: either "path" or "id path" + parts = line.split() + if len(parts) == 1: + rir_path = parts[0] + rir_id = f"rir_{line_idx:06d}" + elif len(parts) == 2: + rir_id, rir_path = parts + else: + logging.warning(f"Invalid line in rir.scp: {line}") + continue + + # Check if file exists + rir_path = Path(rir_path) + if not rir_path.exists(): + logging.warning(f"RIR file not found: {rir_path}") + continue + + # Create recording + recording = Recording( + id=rir_id, + sources=[ + AudioSource( + type="file", + channels=[0], + source=str(rir_path.resolve()), + ) + ], + sampling_rate=16000, # Will be auto-detected + num_samples=None, # Will be auto-detected + duration=None, # Will be auto-detected + ) + + recordings.append(recording) + + logging.info(f"Found {len(recordings)} RIR files") + + # Create recording set and validate + recording_set = RecordingSet.from_recordings(recordings) + + # Save recording manifest + recordings_output_path = output_dir / "rir_recordings.jsonl.gz" + recording_set.to_file(recordings_output_path) + logging.info(f"Saved RIR recording manifest to {recordings_output_path}") + + # Create cuts using lhotse's standard method + logging.info("Creating RIR cuts manifest using lhotse's standard method...") + cuts = CutSet.from_manifests(recordings=recording_set) + + # Save cuts manifest + cuts_output_path = output_dir / "rir_cuts.jsonl.gz" + cuts.to_file(cuts_output_path) + logging.info(f"Saved RIR cuts manifest to {cuts_output_path}") + + # Verify the cuts can be loaded + try: + from lhotse import load_manifest + cuts_test = load_manifest(cuts_output_path) + logging.info(f"✓ Successfully verified: loaded {len(cuts_test)} cuts") + logging.info(f"First cut ID: {cuts_test[0].id}") + logging.info(f"First cut keys: {list(cuts_test[0].to_dict().keys())}") + except Exception as e: + logging.error(f"✗ Failed to verify cuts: {e}") + + # Try CutSet.from_file as fallback + try: + cuts_test2 = CutSet.from_file(cuts_output_path) + logging.info(f"✓ CutSet.from_file worked: loaded {len(cuts_test2)} cuts") + except Exception as e2: + logging.error(f"✗ CutSet.from_file also failed: {e2}") + + return 
recording_set, cuts + + +def main(): + args = get_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + logging.info("Preparing RIR data using lhotse standard approach...") + prepare_rir_manifest( + rir_scp=args.rir_scp, + output_dir=args.output_dir, + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/local/simple_rir.py b/egs/librispeech/ASR/local/simple_rir.py new file mode 100644 index 000000000..7ca50828e --- /dev/null +++ b/egs/librispeech/ASR/local/simple_rir.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +""" +Super simple RIR cuts creator - manual approach without complex lhotse logic +""" + +import argparse +import logging +from pathlib import Path +import json +import gzip +import wave +import soundfile as sf + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--rir-scp", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--max-files", type=int, default=1000) + return parser.parse_args() + +def main(): + args = get_args() + + logging.basicConfig(level=logging.INFO) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + recordings = [] + cuts = [] + + with open(args.rir_scp, 'r') as f: + for idx, line in enumerate(f): + if idx >= args.max_files: + break + + rir_path = Path(line.strip()) + if not rir_path.exists(): + continue + + try: + # Use soundfile to get audio info + info = sf.info(rir_path) + duration = info.duration + sampling_rate = info.samplerate + num_samples = info.frames + + rir_id = f"rir_{idx:06d}" + + # Recording entry - same format as LibriSpeech + recording = { + "id": rir_id, + "sources": [{ + "type": "file", + "channels": [0], + "source": str(rir_path.resolve()) + }], + "sampling_rate": int(sampling_rate), + "num_samples": int(num_samples), + "duration": float(duration), + "channel_ids": [0] + } + recordings.append(recording) + + # Cut entry - same format as LibriSpeech + cut = { + "id": f"{rir_id}-0", + "start": 0.0, + "duration": float(duration), + "channel": 0, + "recording_id": rir_id + } + cuts.append(cut) + + if (idx + 1) % 100 == 0: + logging.info(f"Processed {idx + 1} files...") + + except Exception as e: + logging.warning(f"Failed {rir_path}: {e}") + continue + + logging.info(f"Created {len(recordings)} recordings and {len(cuts)} cuts") + + # Save files + rec_path = args.output_dir / "rir_recordings.jsonl.gz" + with gzip.open(rec_path, 'wt') as f: + for rec in recordings: + f.write(json.dumps(rec) + '\n') + + cuts_path = args.output_dir / "rir_cuts.jsonl.gz" + with gzip.open(cuts_path, 'wt') as f: + for cut in cuts: + f.write(json.dumps(cut) + '\n') + + logging.info(f"Saved to {rec_path} and {cuts_path}") + + # Test loading + try: + from lhotse import load_manifest + test_cuts = load_manifest(cuts_path) + test_recs = load_manifest(rec_path) + logging.info(f"✓ SUCCESS: {len(test_cuts)} cuts, {len(test_recs)} recordings loaded!") + except Exception as e: + logging.error(f"✗ FAILED: {e}") + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/long_file_recog/asr_datamodule.py b/egs/librispeech/ASR/long_file_recog/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/long_file_recog/beam_search.py 
b/egs/librispeech/ASR/long_file_recog/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/__init__.py b/egs/librispeech/ASR/lstm_transducer_stateless2/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/lstm_transducer_stateless3/test_scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/minimal_train.sh b/egs/librispeech/ASR/minimal_train.sh new file mode 100644 index 000000000..4cbb0b3b1 --- /dev/null +++ b/egs/librispeech/ASR/minimal_train.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +# minimal_train.sh - 최소한의 설정으로 안정적인 훈련 +set -euo pipefail + +# 매우 보수적인 설정 +world_size=1 # 단일 GPU로 시작 +max_duration=100 # 매우 작은 배치 크기 +valid_max_duration=10 +num_buckets=50 +num_workers=2 +warm_step=10000 +lang_dir="./data/lang_bpe_5000" +method="ctc-decoding" + +# Model parameters +att_rate=0 +num_decoder_layers=0 + +# Other settings +start_epoch=19 +master_port=12346 +sanity_check=false + +# Validation 완전히 비활성화 +enable_validation=false + +if [ -z "${PYTHONPATH:-}" ]; then + export PYTHONPATH="/tmp/icefall" +else + export PYTHONPATH="${PYTHONPATH}:/tmp/icefall" +fi + +echo "🚀 Starting minimal stable training..." 
+echo "World size: $world_size" +echo "Max duration: $max_duration" +echo "Validation: $enable_validation" + +python3 ./conformer_ctc/train.py \ + --master-port $master_port \ + --sanity-check $sanity_check \ + --world-size $world_size \ + --warm-step $warm_step \ + --start-epoch $start_epoch \ + --att-rate $att_rate \ + --num-decoder-layers $num_decoder_layers \ + --num-workers $num_workers \ + --enable-spec-aug false \ + --enable-musan false \ + --enable-rir false \ + --rir-cuts-path data/rir/rir_cuts.jsonl.gz \ + --rir-prob 0.5 \ + --max-duration $max_duration \ + --valid-max-duration $valid_max_duration \ + --num-buckets $num_buckets \ + --bucketing-sampler true \ + --concatenate-cuts false \ + --duration-factor 1.0 \ + --drop-last true \ + --shuffle true \ + --lang-dir $lang_dir \ + --method $method \ + --enable-validation $enable_validation diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 40dc3260d..5756fb4f5 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -5,10 +5,10 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -nj=15 +nj=30 # run step 0 to step 5 by default -stage=0 -stop_stage=5 +stage=-1 # 1 +stop_stage=-1 # 6 # Note: This script just prepare the minimal requirements that needed by a # transducer training with bpe units. @@ -54,7 +54,7 @@ stop_stage=5 # - librispeech-lexicon.txt # - librispeech-lm-norm.txt.gz -dl_dir=$PWD/download +dl_dir=/home/hdd1/jenny . shared/parse_options.sh || exit 1 @@ -62,10 +62,10 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - # 5000 + 5000 # 2000 # 1000 - 500 + # 500 ) # All files generated by this script are saved in "data". @@ -119,7 +119,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/LibriSpeech mkdir -p data/manifests if [ ! -e data/manifests/.librispeech.done ]; then - lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests + lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech/LibriSpeech data/manifests touch data/manifests/.librispeech.done fi fi @@ -130,7 +130,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # to $dl_dir/musan mkdir -p data/manifests if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests + lhotse prepare musan /home/nas3/DB/musan data/manifests touch data/manifests/.musan.done fi fi diff --git a/egs/librispeech/ASR/prepare_chime4.py b/egs/librispeech/ASR/prepare_chime4.py new file mode 100644 index 000000000..978f2b659 --- /dev/null +++ b/egs/librispeech/ASR/prepare_chime4.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Prepare CHiME-4 dataset for icefall ASR experiments. +Creates lhotse manifests for CHiME-4 audio and supervision data. 
+""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Dict, List, Tuple + +from lhotse import CutSet, Recording, RecordingSet, SupervisionSegment, SupervisionSet +from lhotse.recipes.utils import read_manifests_if_cached + + +def get_chime4_audio_paths(audio_root: Path) -> Dict[str, List[Path]]: + """Get all CHiME-4 audio file paths organized by subset.""" + audio_paths = {} + + # Define CHiME-4 subsets + subsets = [ + 'dt05_bth', 'dt05_bus_real', 'dt05_bus_simu', 'dt05_caf_real', 'dt05_caf_simu', + 'dt05_ped_real', 'dt05_ped_simu', 'dt05_str_real', 'dt05_str_simu', + 'et05_bth', 'et05_bus_real', 'et05_bus_simu', 'et05_caf_real', 'et05_caf_simu', + 'et05_ped_real', 'et05_ped_simu', 'et05_str_real', 'et05_str_simu', + 'tr05_bth', 'tr05_bus_real', 'tr05_bus_simu', 'tr05_caf_real', 'tr05_caf_simu', + 'tr05_org', 'tr05_ped_real', 'tr05_ped_simu', 'tr05_str_real', 'tr05_str_simu' + ] + + for subset in subsets: + subset_dir = audio_root / subset + if subset_dir.exists(): + wav_files = list(subset_dir.glob("*.wav")) + if wav_files: + audio_paths[subset] = wav_files + logging.info(f"Found {len(wav_files)} files in {subset}") + + return audio_paths + + +def parse_chime4_transcription_file(trn_file: Path) -> List[Tuple[str, str]]: + """Parse CHiME-4 transcription file and return list of (utterance_id, text) pairs.""" + transcriptions = [] + + with open(trn_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + # CHiME-4 transcription format: "text (utterance_id)" + if line.endswith(')') and '(' in line: + parts = line.rsplit('(', 1) + if len(parts) == 2: + text = parts[0].strip() + utterance_id = parts[1].rstrip(')').strip() + transcriptions.append((utterance_id, text)) + + return transcriptions + + +def get_chime4_transcriptions(transcript_root: Path) -> Dict[str, str]: + """Get all CHiME-4 transcriptions organized by utterance ID.""" + all_transcriptions = {} + + # Process individual subset transcription files + for trn_file in transcript_root.glob("*/*.trn"): + subset_name = trn_file.parent.name + logging.info(f"Processing transcriptions from {trn_file}") + + transcriptions = parse_chime4_transcription_file(trn_file) + for utterance_id, text in transcriptions: + all_transcriptions[utterance_id] = text + + logging.info(f"Added {len(transcriptions)} transcriptions from {subset_name}") + + # Also process .trn_all files + for trn_all_file in transcript_root.glob("*.trn_all"): + logging.info(f"Processing transcriptions from {trn_all_file}") + + transcriptions = parse_chime4_transcription_file(trn_all_file) + for utterance_id, text in transcriptions: + all_transcriptions[utterance_id] = text + + logging.info(f"Added {len(transcriptions)} transcriptions from {trn_all_file.name}") + + return all_transcriptions + + +def create_chime4_recordings(audio_paths: Dict[str, List[Path]]) -> RecordingSet: + """Create RecordingSet from CHiME-4 audio files.""" + recordings = [] + + for subset, wav_files in audio_paths.items(): + for wav_file in wav_files: + # Extract utterance ID from filename + # Example: F01_22GC010A_BTH.CH0.wav -> F01_22GC010A_BTH + utterance_id = wav_file.stem + if '.CH' in utterance_id: + utterance_id = utterance_id.split('.CH')[0] + + try: + recording = Recording.from_file(wav_file) + # Create new recording with custom ID instead of using with_id() + recording = Recording( + id=utterance_id, + sources=recording.sources, + sampling_rate=recording.sampling_rate, + num_samples=recording.num_samples, + 
duration=recording.duration, + channel_ids=recording.channel_ids, + transforms=recording.transforms + ) + recordings.append(recording) + except Exception as e: + logging.warning(f"Failed to process {wav_file}: {e}") + continue + + logging.info(f"Created {len(recordings)} recordings") + return RecordingSet.from_recordings(recordings) + + +def create_chime4_supervisions(transcriptions: Dict[str, str], recordings: RecordingSet) -> SupervisionSet: + """Create SupervisionSet from CHiME-4 transcriptions.""" + supervisions = [] + + for recording in recordings: + utterance_id = recording.id + if utterance_id in transcriptions: + text = transcriptions[utterance_id] + supervision = SupervisionSegment( + id=utterance_id, + recording_id=utterance_id, + start=0.0, + duration=recording.duration, + channel=0, + text=text, + language="English" + ) + supervisions.append(supervision) + else: + logging.warning(f"No transcription found for {utterance_id}") + + logging.info(f"Created {len(supervisions)} supervisions") + return SupervisionSet.from_segments(supervisions) + + +def prepare_chime4( + audio_root: Path, + transcript_root: Path, + output_dir: Path +) -> None: + """Prepare CHiME-4 dataset and save manifests.""" + + output_dir.mkdir(parents=True, exist_ok=True) + + # Get audio file paths + logging.info("Scanning for CHiME-4 audio files...") + audio_paths = get_chime4_audio_paths(audio_root) + + # Get transcriptions + logging.info("Loading CHiME-4 transcriptions...") + transcriptions = get_chime4_transcriptions(transcript_root) + logging.info(f"Loaded {len(transcriptions)} transcriptions") + + # Create recordings + logging.info("Creating recordings manifest...") + recordings = create_chime4_recordings(audio_paths) + + # Create supervisions + logging.info("Creating supervisions manifest...") + supervisions = create_chime4_supervisions(transcriptions, recordings) + + # Create cuts + logging.info("Creating cuts manifest...") + cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions) + + # Separate by evaluation sets (dt05, et05) and training (tr05) + dt05_cuts = cuts.filter(lambda cut: cut.id.startswith('dt05') or 'dt05' in cut.id) + et05_cuts = cuts.filter(lambda cut: cut.id.startswith('et05') or 'et05' in cut.id) + tr05_cuts = cuts.filter(lambda cut: cut.id.startswith('tr05') or 'tr05' in cut.id) + + # Save manifests + logging.info("Saving manifests...") + + if len(dt05_cuts) > 0: + dt05_recordings = recordings.filter(lambda r: r.id in [c.recording.id for c in dt05_cuts]) + dt05_supervisions = supervisions.filter(lambda s: s.recording_id in [c.recording.id for c in dt05_cuts]) + + dt05_recordings.to_file(output_dir / "chime4_recordings_dt05.jsonl.gz") + dt05_supervisions.to_file(output_dir / "chime4_supervisions_dt05.jsonl.gz") + dt05_cuts.to_file(output_dir / "chime4_cuts_dt05.jsonl.gz") + logging.info(f"Saved dt05 manifests with {len(dt05_cuts)} cuts") + + if len(et05_cuts) > 0: + et05_recordings = recordings.filter(lambda r: r.id in [c.recording.id for c in et05_cuts]) + et05_supervisions = supervisions.filter(lambda s: s.recording_id in [c.recording.id for c in et05_cuts]) + + et05_recordings.to_file(output_dir / "chime4_recordings_et05.jsonl.gz") + et05_supervisions.to_file(output_dir / "chime4_supervisions_et05.jsonl.gz") + et05_cuts.to_file(output_dir / "chime4_cuts_et05.jsonl.gz") + logging.info(f"Saved et05 manifests with {len(et05_cuts)} cuts") + + if len(tr05_cuts) > 0: + tr05_recordings = recordings.filter(lambda r: r.id in [c.recording.id for c in tr05_cuts]) + 
tr05_supervisions = supervisions.filter(lambda s: s.recording_id in [c.recording.id for c in tr05_cuts]) + + tr05_recordings.to_file(output_dir / "chime4_recordings_tr05.jsonl.gz") + tr05_supervisions.to_file(output_dir / "chime4_supervisions_tr05.jsonl.gz") + tr05_cuts.to_file(output_dir / "chime4_cuts_tr05.jsonl.gz") + logging.info(f"Saved tr05 manifests with {len(tr05_cuts)} cuts") + + logging.info(f"CHiME-4 data preparation completed. Manifests saved to {output_dir}") + + +def main(): + parser = argparse.ArgumentParser(description="Prepare CHiME-4 dataset for icefall") + parser.add_argument( + "--audio-root", + type=Path, + default=Path("/home/nas/DB/CHiME4/data/audio/16kHz/isolated"), + help="Path to CHiME-4 audio root directory" + ) + parser.add_argument( + "--transcript-root", + type=Path, + default=Path("/home/nas/DB/CHiME4/data/transcriptions"), + help="Path to CHiME-4 transcription root directory" + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/chime4"), + help="Output directory for manifest files" + ) + + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + prepare_chime4(args.audio_root, args.transcript_root, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/prepare_rir_data.sh b/egs/librispeech/ASR/prepare_rir_data.sh new file mode 100755 index 000000000..e17d17f1e --- /dev/null +++ b/egs/librispeech/ASR/prepare_rir_data.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# prepare_rir_data.sh +# Script to prepare RIR data for icefall training + +set -euo pipefail + +stage=0 +stop_stage=100 + +# Directories and files +rir_scp="/home/hdd2/jenny/ASRToolkit/icefall/egs/librispeech/ASR/data/manifests/rir.scp" # Path to your rir.scp file +data_dir="data/rir" +rir_cuts_manifest="$data_dir/rir_cuts.jsonl.gz" + +. shared/parse_options.sh || exit 1 + +if [ $# != 1 ]; then + echo "Usage: $0 <rir-scp>" + echo "e.g.: $0 /path/to/your/rir.scp" + echo "" + echo "Options:" + echo " --stage <stage> # Stage to start from (default: 0)" + echo " --stop-stage <stop-stage> # Stage to stop at (default: 100)" + echo " --data-dir <dir> # Output directory (default: data/rir)" + exit 1 +fi + +rir_scp=$1 + +if [ ! -f "$rir_scp" ]; then + echo "Error: RIR scp file not found: $rir_scp" + exit 1 +fi + +log() { + echo "[$(date +'%Y-%m-%d %H:%M:%S')] $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Preparing RIR manifest from $rir_scp" + + mkdir -p $data_dir + + python local/prepare_rir.py \ + --rir-scp $rir_scp \ + --output-dir $data_dir + + log "RIR manifest saved to $rir_cuts_manifest" +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Validating RIR manifest" + + if [ ! -f "$rir_cuts_manifest" ]; then + echo "Error: RIR cuts manifest not found: $rir_cuts_manifest" + exit 1 + fi + + # Count number of RIR files + python -c " +from lhotse import load_manifest +cuts = load_manifest('$rir_cuts_manifest') +print(f'Successfully loaded {len(cuts)} RIR cuts') +print(f'Total duration: {cuts.total_duration():.2f} seconds') +print(f'Average duration: {cuts.total_duration()/len(cuts):.3f} seconds') +" + + log "RIR data preparation completed successfully!"
+fi + +log "To use RIR augmentation in training, add these options:" +log " --enable-rir True" +log " --rir-cuts-path $rir_cuts_manifest" +log " --rir-prob 0.5 # Adjust probability as needed" diff --git a/egs/librispeech/ASR/pruned2_knowledge/__init__.py b/egs/librispeech/ASR/pruned2_knowledge/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/noam.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/noam.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py 
b/egs/librispeech/ASR/pruned_transducer_stateless/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless5/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py 
old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py old mode 100644 new mode 100755 diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/__init__.py b/egs/librispeech/ASR/tdnn_lstm_ctc/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py old mode 100644 new mode 100755 index 1b52aa8b5..913add06c --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -19,9 +19,10 @@ import argparse import inspect import logging +import warnings from functools import lru_cache from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy @@ -38,12 +39,58 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, OnTheFlyFeatures, ) +from lhotse.augmentation import ReverbWithImpulseResponse from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool +# Filter out RIR reverberation warnings +class RIRWarningFilter(logging.Filter): + def filter(self, record): + return not ("Attempting to reverberate" in record.getMessage() and "pre-computed features" in record.getMessage()) + +# Apply the filter to root logger +logging.getLogger().addFilter(RIRWarningFilter()) + + +class 
RandomRIRTransform: + """ + Random RIR (Room Impulse Response) transform that applies reverberation + to CutSet using lhotse's built-in reverb_rir method. + """ + def __init__(self, rir_paths, prob=0.5): + from lhotse import Recording, RecordingSet + # Load RIR recordings from file paths + self.rir_recordings = [] + for i, rir_path in enumerate(rir_paths[:50]): # Limit to first 50 for memory + try: + rir_rec = Recording.from_file(rir_path) + # Resample to 16kHz if needed + if rir_rec.sampling_rate != 16000: + rir_rec = rir_rec.resample(16000) + self.rir_recordings.append(rir_rec) + except Exception as e: + continue # Skip problematic files + + # Create RecordingSet from loaded recordings + if self.rir_recordings: + self.rir_recording_set = RecordingSet.from_recordings(self.rir_recordings) + else: + self.rir_recording_set = None + + self.prob = prob + print(f"Loaded {len(self.rir_recordings)} RIR recordings for augmentation") + + def __call__(self, cuts): + """Apply RIR to CutSet with specified probability.""" + import random + if random.random() < self.prob and self.rir_recording_set is not None: + # Apply reverb_rir to the entire CutSet + return cuts.reverb_rir(rir_recordings=self.rir_recording_set) + return cuts + class _SeedWorkers: def __init__(self, seed: int): self.seed = seed @@ -109,6 +156,14 @@ class LibriSpeechAsrDataModule: help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) + group.add_argument( + "--valid-max-duration", + type=int, + default=None, + help="Maximum pooled recordings duration (seconds) in a " + "single validation batch. If None, uses --max-duration. " + "You should reduce this if validation causes CUDA OOM.", + ) group.add_argument( "--bucketing-sampler", type=str2bool, @@ -186,7 +241,7 @@ class LibriSpeechAsrDataModule: group.add_argument( "--enable-spec-aug", type=str2bool, - default=True, + default=False, help="When enabled, use SpecAugment for training dataset.", ) @@ -203,11 +258,34 @@ class LibriSpeechAsrDataModule: group.add_argument( "--enable-musan", type=str2bool, - default=True, + default=False, help="When enabled, select noise from MUSAN and mix it" "with training dataset. ", ) + group.add_argument( + "--enable-rir", + type=str2bool, + default=False, + help="When enabled, convolve training data with RIR " + "(Room Impulse Response) for data augmentation.", + ) + + group.add_argument( + "--rir-cuts-path", + type=Path, + default=None, + help="Path to RIR cuts manifest file (e.g., data/rir/rir_cuts.jsonl.gz). " + "Required when --enable-rir is True.", + ) + + group.add_argument( + "--rir-prob", + type=float, + default=0.5, + help="Probability of applying RIR augmentation to each utterance.", + ) + group.add_argument( "--input-strategy", type=str, @@ -227,17 +305,40 @@ class LibriSpeechAsrDataModule: sampler_state_dict: The state dict for the training sampler. 
""" + # Setup augmentation transforms (for noisy dataset) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz") transforms.append( CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") + if self.args.enable_rir: + logging.info("Enable RIR (Room Impulse Response) augmentation") + logging.info(f"Loading RIR paths from {self.args.rir_cuts_path}") + + # Load RIR file paths from rir.scp + rir_paths = [] + try: + with open("data/manifests/rir.scp", "r") as f: + rir_paths = [line.strip() for line in f if line.strip()] + logging.info(f"Found {len(rir_paths)} RIR files") + except FileNotFoundError: + logging.warning("RIR file data/manifests/rir.scp not found, skipping RIR augmentation") + rir_paths = [] + + if rir_paths: + # Use the module-level RandomRIRTransform class with audio-level processing + transforms.append( + RandomRIRTransform(rir_paths, prob=self.args.rir_prob) + ) + else: + logging.info("Disable RIR augmentation") + if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " @@ -278,32 +379,37 @@ class LibriSpeechAsrDataModule: else: logging.info("Disable SpecAugment") + # Create input strategy (same for both clean and noisy - only transforms differ) + input_strategy = eval(self.args.input_strategy)() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + + # Create clean dataset (no augmentation) + # Create train dataset (with augmentations) logging.info("About to create train dataset") + augmentation_details = [] + if transforms: + transform_names = [type(t).__name__ for t in transforms] + augmentation_details.append(f"Cut transforms: {transform_names}") + if input_transforms: + input_transform_names = [type(t).__name__ for t in input_transforms] + augmentation_details.append(f"Input transforms: {input_transform_names}") + + if augmentation_details: + logging.info(f"Train dataset augmentations: {'; '.join(augmentation_details)}") + else: + logging.info("Train dataset: No augmentations will be applied") + + logging.info(f"Train dataset: {len(transforms)} cut transforms, {len(input_transforms)} input transforms") + train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, + input_strategy=input_strategy, + cut_transforms=transforms, # Apply cut augmentations (MUSAN, RIR, concat) + input_transforms=input_transforms, # Apply input augmentations (SpecAugment) return_cuts=self.args.return_cuts, ) - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - + # Create sampler if self.args.bucketing_sampler: logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( @@ -322,6 +428,7 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, ) + logging.info("About to create train dataloader") if sampler_state_dict is not None: @@ -353,6 +460,10 @@ class LibriSpeechAsrDataModule: ) ] + transforms + # Determine the max_duration for validation + valid_max_duration = self.args.valid_max_duration if self.args.valid_max_duration is not None else self.args.max_duration + logging.info(f"Validation max_duration: {valid_max_duration} seconds") + logging.info("About to create dev dataset") if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( @@ -367,7 +478,7 @@ class LibriSpeechAsrDataModule: ) valid_sampler = DynamicBucketingSampler( cuts_valid, - max_duration=self.args.max_duration, + max_duration=valid_max_duration, shuffle=False, ) logging.info("About to create dev dataloader") @@ -403,6 +514,29 @@ class LibriSpeechAsrDataModule: ) return test_dl + def all_test_dataloaders(self) -> Dict[str, DataLoader]: + """ + Returns all test dataloaders including LibriSpeech and CHiME-4. + + Returns: + Dict[str, DataLoader]: Dictionary with test set names as keys and DataLoaders as values + """ + test_dataloaders = {} + + # LibriSpeech test sets + test_clean_cuts = self.test_clean_cuts() + test_other_cuts = self.test_other_cuts() + + test_dataloaders["test-clean"] = self.test_dataloaders(test_clean_cuts) + test_dataloaders["test-other"] = self.test_dataloaders(test_other_cuts) + + # CHiME-4 test sets + chime4_dls = self.chime4_test_dataloaders() + for test_set_name, dl in chime4_dls.items(): + test_dataloaders[f"chime4-{test_set_name}"] = dl + + return test_dataloaders + @lru_cache() def train_clean_5_cuts(self) -> CutSet: logging.info("mini_librispeech: About to get train-clean-5 cuts") @@ -490,3 +624,137 @@ class LibriSpeechAsrDataModule: def gigaspeech_test_cuts(self) -> CutSet: logging.info("About to get Gigaspeech test cuts") return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + + def chime4_test_dataloaders(self) -> Dict[str, DataLoader]: + """Create CHiME-4 test dataloaders for different conditions.""" + from pathlib import Path + + chime4_audio_root = Path("/home/nas/DB/CHiME4/data/audio/16kHz/isolated") + chime4_transcript_root = Path("/home/nas/DB/CHiME4/data/transcriptions") + + test_loaders = {} + + # Define test sets: dt05 (development) and et05 (evaluation) + test_sets = ["dt05_bth", "et05_bth"] # Start with booth (clean) conditions + + for test_set in test_sets: + try: + audio_dir = chime4_audio_root / test_set + transcript_dir = chime4_transcript_root / test_set + + if not audio_dir.exists() or not transcript_dir.exists(): + logging.warning(f"CHiME-4 {test_set} not found, skipping") + continue + + # Create cuts for this test set + cuts = self._create_chime4_cuts(audio_dir, transcript_dir, max_files=50) + + if len(cuts) == 0: + logging.warning(f"No valid cuts for CHiME-4 {test_set}") + continue + + # Create test dataset + test_dataset = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + + # Create sampler + sampler = 
DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + # Create dataloader + test_dl = DataLoader( + test_dataset, + batch_size=None, + sampler=sampler, + num_workers=2, + ) + + test_loaders[test_set] = test_dl + logging.info(f"Created CHiME-4 {test_set} dataloader with {len(cuts)} cuts") + + except Exception as e: + logging.warning(f"Failed to create CHiME-4 {test_set} dataloader: {e}") + + return test_loaders + + def _create_chime4_cuts(self, audio_dir: Path, transcript_dir: Path, max_files: int = 50) -> CutSet: + """Helper to create CutSet from CHiME-4 audio and transcripts.""" + from lhotse import CutSet, Recording, RecordingSet, SupervisionSegment, SupervisionSet + + # Get audio files (limit for testing) + wav_files = sorted(list(audio_dir.glob("*.wav")))[:max_files] + + # Parse transcriptions + transcriptions = {} + for trn_file in transcript_dir.glob("*.trn"): + try: + with open(trn_file, 'r', encoding='utf-8') as f: + line = f.read().strip() + if line: + parts = line.split(' ', 1) + if len(parts) == 2: + utterance_id = parts[0] + text = parts[1] + transcriptions[utterance_id] = text + except Exception as e: + logging.warning(f"Failed to read {trn_file}: {e}") + + # Create recordings and supervisions + recordings = [] + supervisions = [] + + for wav_file in wav_files: + # Extract utterance ID from filename (remove .CH0, etc.) + utterance_id = wav_file.stem + if '.CH' in utterance_id: + utterance_id = utterance_id.split('.CH')[0] + + # Skip if no transcription + if utterance_id not in transcriptions: + continue + + try: + # Create recording + recording = Recording.from_file(wav_file) + recording = Recording( + id=utterance_id, + sources=recording.sources, + sampling_rate=recording.sampling_rate, + num_samples=recording.num_samples, + duration=recording.duration, + channel_ids=recording.channel_ids, + transforms=recording.transforms + ) + recordings.append(recording) + + # Create supervision + text = transcriptions[utterance_id] + supervision = SupervisionSegment( + id=utterance_id, + recording_id=utterance_id, + start=0.0, + duration=recording.duration, + channel=0, + text=text, + language="English" + ) + supervisions.append(supervision) + + except Exception as e: + logging.warning(f"Failed to process {wav_file}: {e}") + continue + + if not recordings: + return CutSet.from_cuts([]) # Empty CutSet + + # Create manifests + recording_set = RecordingSet.from_recordings(recordings) + supervision_set = SupervisionSet.from_segments(supervisions) + cuts = CutSet.from_manifests(recordings=recording_set, supervisions=supervision_set) + + return cuts diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/README.md b/egs/librispeech/ASR/tiny_transducer_ctc/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py old mode 100644 new mode 100755 diff --git 
a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/train_clean.sh b/egs/librispeech/ASR/train_clean.sh new file mode 100644 index 000000000..6ac6d6170 --- /dev/null +++ b/egs/librispeech/ASR/train_clean.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +# train_clean.sh - LibriSpeech ASR Training Script with Data Augmentation Control +# Usage: bash train_clean.sh + +set -euo pipefail + +# Data Augmentation Controls (modify these as needed) +enable_spec_aug=false # SpecAugment (frequency/time masking) +enable_musan=false # MUSAN noise augmentation +enable_rir=false +enable_cutmix=false +enable_concatenate=false + +# RIR settings (used when enable_rir=true) +rir_cuts_path="data/rir/rir_cuts.jsonl.gz" +rir_prob=0.5 + +# Training parameters +world_size=3 +max_duration=300 +valid_max_duration=15 +num_buckets=200 +num_workers=32 +warm_step=10000 +lang_dir="./data/lang_bpe_5000" +method="ctc-decoding" + +# Model parameters +att_rate=0 # 0 for pure CTC, >0 for CTC+Attention +num_decoder_layers=0 # 0 for pure CTC + +# Other settings +start_epoch=78 +master_port=12346 +sanity_check=false # Set to true for OOM checking (slower) + +# Validation settings +enable_validation=true # Set to false to temporarily disable validation (avoids crashes) +valid_interval=5000 # Much larger interval if we enable validation later + +# Validation decoding settings +validation_decoding_method="greedy" # "greedy" or "beam" - use greedy for faster validation +validation_search_beam=10.0 # Beam size for validation (only used if method="beam") +validation_output_beam=5.0 # Output beam for validation (only used if method="beam") +validation_skip_wer=false # Skip WER computation for even faster validation (was true for debugging; now set back to false) + +if [ "$enable_rir" = "true" ]; then + echo " - RIR Path: $rir_cuts_path" + echo " - RIR Probability: $rir_prob" +fi + + +# gdb --args python ./conformer_ctc/train.py +if [ -z "${PYTHONPATH:-}" ]; then + export PYTHONPATH="/tmp/icefall" +else + export PYTHONPATH="${PYTHONPATH}:/tmp/icefall" +fi + +python3 ./conformer_ctc/train.py \ + --master-port $master_port \ + --sanity-check $sanity_check \ + --world-size $world_size \ + --warm-step $warm_step \ + --start-epoch $start_epoch \ + --att-rate $att_rate \ + --num-decoder-layers $num_decoder_layers \ + --num-workers $num_workers \ + --enable-spec-aug $enable_spec_aug \ + --enable-musan $enable_musan \ + --enable-rir $enable_rir \ + --rir-cuts-path $rir_cuts_path \ + --rir-prob $rir_prob \ + --max-duration $max_duration \ + --valid-max-duration $valid_max_duration \ + --num-buckets $num_buckets \ + --bucketing-sampler true \ + --concatenate-cuts $enable_concatenate \ + --duration-factor 1.0 \ + --drop-last true \ + --shuffle true \ + --lang-dir $lang_dir \ + --method $method \ + --enable-validation $enable_validation \ + --valid-interval $valid_interval \ + --validation-decoding-method $validation_decoding_method \ + --validation-search-beam $validation_search_beam \ + --validation-output-beam $validation_output_beam \ + --validation-skip-wer $validation_skip_wer
diff --git a/egs/librispeech/ASR/train_noisy.sh b/egs/librispeech/ASR/train_noisy.sh new file mode 100644 index 000000000..814559c46 --- /dev/null +++ b/egs/librispeech/ASR/train_noisy.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +# train_noisy.sh - LibriSpeech ASR Training Script with Data Augmentation Control +# Usage: bash train_noisy.sh + +set -euo pipefail + +# Data Augmentation Controls (modify these as needed) +enable_spec_aug=true # SpecAugment (frequency/time masking) +enable_musan=true # MUSAN noise augmentation +enable_rir=true # RIR (Room Impulse Response) augmentation - RE-ENABLED +enable_cutmix=true # Cut mixing: mixes time regions of two audio cuts +enable_concatenate=true # Cut concatenation: concatenates short utterances to minimize padding + +# RIR settings (used when enable_rir=true) +rir_cuts_path="data/manifests/rir.scp" # Path to RIR file list (updated to use rir.scp) +rir_prob=0.5 # Probability of applying RIR + +# Training parameters +world_size=4 # Multi-GPU restored since test passed +max_duration=300 # Further reduced from 320 to save memory +valid_max_duration=15 # Very small for multi-GPU safety +num_buckets=200 # Reduced for memory saving +num_workers=24 # Much smaller to save memory +warm_step=40000 +lang_dir="./data/lang_bpe_5000" +method="ctc-decoding" + +# Model parameters +att_rate=0 # 0 for pure CTC, >0 for CTC+Attention +num_decoder_layers=0 # 0 for pure CTC + +# Other settings +start_epoch=0 +master_port=12345 +sanity_check=false # Set to true for OOM checking (slower) + +# Validation settings +enable_validation=true # Set to false to disable validation completely +valid_interval=5000 # Increased from 50 to allow more training before validation + +# Validation decoding settings +validation_decoding_method="greedy" # "greedy" or "beam" - use greedy for faster validation +validation_search_beam=10.0 # Beam size for validation (only used if method="beam") +validation_output_beam=5.0 # Output beam for validation (only used if method="beam") +validation_skip_wer=false # Skip WER computation for even faster validation (was true for debugging; now set back to false) + +if [ "$enable_rir" = "true" ]; then + echo " - RIR Path: $rir_cuts_path" + echo " - RIR Probability: $rir_prob" +fi + + +# gdb --args python ./conformer_ctc/train.py
if [ -z "${PYTHONPATH:-}" ]; then + export PYTHONPATH="/tmp/icefall" +else + export PYTHONPATH="${PYTHONPATH}:/tmp/icefall" +fi + + +python3 ./conformer_ctc/train.py \ + --master-port $master_port \ + --sanity-check $sanity_check \ + --world-size $world_size \ + --warm-step $warm_step \ + --start-epoch $start_epoch \ + --att-rate $att_rate \ + --num-decoder-layers $num_decoder_layers \ + --num-workers $num_workers \ + --enable-spec-aug $enable_spec_aug \ + --enable-musan $enable_musan \ + --enable-rir $enable_rir \ + --rir-cuts-path $rir_cuts_path \ + --rir-prob $rir_prob \ + --on-the-fly-feats true \ + --max-duration $max_duration \ + --valid-max-duration $valid_max_duration \ + --num-buckets $num_buckets \ + --bucketing-sampler true \ + --concatenate-cuts $enable_concatenate \ + --duration-factor 1.0 \ + --drop-last true \ + --shuffle true \ + --lang-dir $lang_dir \ + --method $method \ + --enable-validation $enable_validation \ + --valid-interval $valid_interval \ + --validation-decoding-method $validation_decoding_method \ + --validation-search-beam $validation_search_beam \ + --validation-output-beam $validation_output_beam \ + --validation-skip-wer $validation_skip_wer diff --git a/egs/librispeech/ASR/training.log b/egs/librispeech/ASR/training.log new file mode 100644 index 000000000..6af3f46d7 --- /dev/null +++ b/egs/librispeech/ASR/training.log @@ -0,0 +1,400 @@ +nohup: ignoring
input + - RIR Path: data/manifests/rir.scp + - RIR Probability: 0.5 +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +2025-08-26 22:40:01,609 INFO [train.py:958] (0/3) Training started +2025-08-26 22:40:01,610 INFO [train.py:959] (0/3) Warmup steps: 30000 +2025-08-26 22:40:01,610 INFO [train.py:960] (0/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + 
"lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + "validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-26 22:40:01,734 INFO [train.py:958] (1/3) Training started +2025-08-26 22:40:01,734 INFO [train.py:959] (1/3) Warmup steps: 30000 +2025-08-26 22:40:01,734 INFO [train.py:960] (1/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + "lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + 
"validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-26 22:40:01,758 INFO [train.py:958] (2/3) Training started +2025-08-26 22:40:01,758 INFO [train.py:959] (2/3) Warmup steps: 30000 +2025-08-26 22:40:01,758 INFO [train.py:960] (2/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + "lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + "validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-26 22:40:01,977 INFO [train.py:1012] (0/3) About to create model +2025-08-26 22:40:02,103 INFO [train.py:1012] (1/3) About to create model +2025-08-26 22:40:02,125 INFO [train.py:1012] (2/3) About to create model +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-26 22:40:03,809 INFO [asr_datamodule.py:539] (0/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +2025-08-26 22:40:03,811 INFO [asr_datamodule.py:303] (0/3) Enable MUSAN +2025-08-26 22:40:03,811 INFO [asr_datamodule.py:304] (0/3) About to get Musan 
cuts +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-26 22:40:03,957 INFO [asr_datamodule.py:539] (2/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-26 22:40:03,957 INFO [asr_datamodule.py:539] (1/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +2025-08-26 22:40:03,958 INFO [asr_datamodule.py:303] (2/3) Enable MUSAN +2025-08-26 22:40:03,958 INFO [asr_datamodule.py:304] (2/3) About to get Musan cuts +2025-08-26 22:40:03,959 INFO [asr_datamodule.py:303] (1/3) Enable MUSAN +2025-08-26 22:40:03,959 INFO [asr_datamodule.py:304] (1/3) About to get Musan cuts +2025-08-26 22:40:06,823 INFO [asr_datamodule.py:313] (2/3) Enable RIR (Room Impulse Response) augmentation +2025-08-26 22:40:06,823 INFO [asr_datamodule.py:314] (2/3) Loading RIR paths from data/manifests/rir.scp +2025-08-26 22:40:06,846 INFO [asr_datamodule.py:321] (2/3) Found 60536 RIR files +2025-08-26 22:40:06,847 INFO [asr_datamodule.py:313] (0/3) Enable RIR (Room Impulse Response) augmentation +2025-08-26 22:40:06,847 INFO [asr_datamodule.py:314] (0/3) Loading RIR paths from data/manifests/rir.scp +2025-08-26 22:40:06,863 INFO [asr_datamodule.py:313] (1/3) Enable RIR (Room Impulse Response) augmentation +2025-08-26 22:40:06,863 INFO [asr_datamodule.py:314] (1/3) Loading RIR paths from data/manifests/rir.scp +2025-08-26 22:40:06,863 INFO [asr_datamodule.py:335] (2/3) Using cut concatenation with duration factor 1.0 and gap 1.0. +2025-08-26 22:40:06,863 INFO [asr_datamodule.py:350] (2/3) Enable SpecAugment +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:351] (2/3) Time warp factor: 80 +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:361] (2/3) Num frame mask: 10 +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:381] (2/3) About to create train dataset +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:391] (2/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:395] (2/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-26 22:40:06,864 INFO [asr_datamodule.py:406] (2/3) Using DynamicBucketingSampler. +2025-08-26 22:40:06,870 INFO [asr_datamodule.py:321] (0/3) Found 60536 RIR files +2025-08-26 22:40:06,885 INFO [asr_datamodule.py:321] (1/3) Found 60536 RIR files +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:335] (0/3) Using cut concatenation with duration factor 1.0 and gap 1.0. 
+2025-08-26 22:40:06,887 INFO [asr_datamodule.py:350] (0/3) Enable SpecAugment +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:351] (0/3) Time warp factor: 80 +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:361] (0/3) Num frame mask: 10 +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:381] (0/3) About to create train dataset +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:391] (0/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:395] (0/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-26 22:40:06,887 INFO [asr_datamodule.py:406] (0/3) Using DynamicBucketingSampler. +2025-08-26 22:40:06,902 INFO [asr_datamodule.py:335] (1/3) Using cut concatenation with duration factor 1.0 and gap 1.0. +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:350] (1/3) Enable SpecAugment +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:351] (1/3) Time warp factor: 80 +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:361] (1/3) Num frame mask: 10 +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:381] (1/3) About to create train dataset +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:391] (1/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:395] (1/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-26 22:40:06,903 INFO [asr_datamodule.py:406] (1/3) Using DynamicBucketingSampler. +2025-08-26 22:40:07,548 INFO [asr_datamodule.py:424] (2/3) About to create train dataloader +2025-08-26 22:40:07,551 INFO [asr_datamodule.py:556] (2/3) About to get dev-clean cuts +2025-08-26 22:40:07,552 INFO [asr_datamodule.py:457] (2/3) Validation max_duration: 15 seconds +2025-08-26 22:40:07,552 INFO [asr_datamodule.py:459] (2/3) About to create dev dataset +2025-08-26 22:40:07,561 INFO [asr_datamodule.py:424] (0/3) About to create train dataloader +2025-08-26 22:40:07,564 INFO [asr_datamodule.py:556] (0/3) About to get dev-clean cuts +2025-08-26 22:40:07,565 INFO [asr_datamodule.py:457] (0/3) Validation max_duration: 15 seconds +2025-08-26 22:40:07,565 INFO [asr_datamodule.py:459] (0/3) About to create dev dataset +2025-08-26 22:40:07,586 INFO [asr_datamodule.py:424] (1/3) About to create train dataloader +2025-08-26 22:40:07,589 INFO [asr_datamodule.py:556] (1/3) About to get dev-clean cuts +2025-08-26 22:40:07,590 INFO [asr_datamodule.py:457] (1/3) Validation max_duration: 15 seconds +2025-08-26 22:40:07,590 INFO [asr_datamodule.py:459] (1/3) About to create dev dataset +2025-08-26 22:40:07,718 INFO [asr_datamodule.py:476] (2/3) About to create dev dataloader +2025-08-26 22:40:07,718 INFO [train.py:1068] (2/3) Validation set size: 2703 utterances +2025-08-26 22:40:07,718 INFO [train.py:1129] (2/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. +2025-08-26 22:40:07,732 INFO [asr_datamodule.py:476] (0/3) About to create dev dataloader +2025-08-26 22:40:07,732 INFO [train.py:1068] (0/3) Validation set size: 2703 utterances +2025-08-26 22:40:07,732 INFO [train.py:1129] (0/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. 
+2025-08-26 22:40:07,756 INFO [asr_datamodule.py:476] (1/3) About to create dev dataloader +2025-08-26 22:40:07,757 INFO [train.py:1068] (1/3) Validation set size: 2703 utterances +2025-08-26 22:40:07,757 INFO [train.py:1129] (1/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. +Loaded 100 RIR recordings for augmentation +W0826 22:44:22.869694 125173910820672 torch/multiprocessing/spawn.py:146] Terminating process 1270919 via signal SIGTERM +W0826 22:44:22.870538 125173910820672 torch/multiprocessing/spawn.py:146] Terminating process 1270920 via signal SIGTERM +Traceback (most recent call last): + File "./conformer_ctc/train.py", line 1415, in <module> + main() + File "./conformer_ctc/train.py", line 1408, in main + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes + while not context.join(): + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 189, in join + raise ProcessRaisedException(msg, error_index, failed_process.pid) +torch.multiprocessing.spawn.ProcessRaisedException: + +-- Process 0 terminated with the following error: +Traceback (most recent call last): + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap + fn(i, *args) + File "/home/hdd2/jenny/ASRToolkit/icefall/egs/librispeech/ASR/conformer_ctc/train.py", line 1071, in run + scan_pessimistic_batches_for_oom( + File "/home/hdd2/jenny/ASRToolkit/icefall/egs/librispeech/ASR/conformer_ctc/train.py", line 1134, in scan_pessimistic_batches_for_oom + batch = train_dl.dataset[cuts] + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/dataset/speech_recognition.py", line 109, in __getitem__ + cuts = tnfm(cuts) +TypeError: __call__() missing 1 required positional argument: 'sampling_rate' + diff --git a/egs/librispeech/ASR/training_fixed.log b/egs/librispeech/ASR/training_fixed.log new file mode 100644 index 000000000..75a4dabf0 --- /dev/null +++ b/egs/librispeech/ASR/training_fixed.log @@ -0,0 +1,386 @@ +nohup: ignoring input + - RIR Path: data/manifests/rir.scp + - RIR Probability: 0.5 +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory
/home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +fatal: detected dubious ownership in repository at '/home/hdd2/jenny/ASRToolkit/icefall' +To add an exception for this directory, call: + + git config --global --add safe.directory /home/hdd2/jenny/ASRToolkit/icefall +2025-08-27 10:02:23,634 INFO [train.py:958] (0/3) Training started +2025-08-27 10:02:23,635 INFO [train.py:959] (0/3) Warmup steps: 30000 +2025-08-27 10:02:23,635 INFO [train.py:960] (0/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + "lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + "validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-27 10:02:23,656 INFO [train.py:958] (2/3) Training 
started +2025-08-27 10:02:23,656 INFO [train.py:959] (2/3) Warmup steps: 30000 +2025-08-27 10:02:23,656 INFO [train.py:960] (2/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + "lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + "validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-27 10:02:23,770 INFO [train.py:958] (1/3) Training started +2025-08-27 10:02:23,770 INFO [train.py:959] (1/3) Warmup steps: 30000 +2025-08-27 10:02:23,770 INFO [train.py:960] (1/3) { + "att_rate": 0.0, + "attention_dim": 256, + "batch_idx_train": 0, + "beam_size": 10, + "best_train_epoch": -1, + "best_train_loss": Infinity, + "best_valid_epoch": -1, + "best_valid_loss": Infinity, + "bpe_dir": "data/lang_bpe_5000", + "bucketing_sampler": true, + "concatenate_cuts": true, + "drop_last": true, + "duration_factor": 1.0, + "enable_musan": true, + "enable_rir": true, + "enable_spec_aug": true, + "enable_validation": true, + "env_info": { + "IP address": "127.0.1.1", + "hostname": "Attention", + "icefall-git-branch": null, + "icefall-git-date": null, + "icefall-git-sha1": null, + "icefall-path": "/tmp/icefall", + "k2-build-type": "Release", + "k2-git-date": "Mon Jul 14 07:51:57 2025", + "k2-git-sha1": "9399d1b01a6309e54b62d885e93209bcd66c1e7d", + "k2-path": 
"/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/k2/__init__.py", + "k2-version": "1.24.4", + "k2-with-cuda": true, + "lhotse-path": "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/lhotse/__init__.py", + "lhotse-version": "1.31.0.dev+git.273e312.clean", + "python-version": "3.8", + "torch-cuda-available": true, + "torch-cuda-version": "12.1", + "torch-version": "2.4.0+cu121" + }, + "exp_dir": "conformer_ctc/exp", + "feature_dim": 80, + "full_libri": true, + "gap": 1.0, + "input_strategy": "PrecomputedFeatures", + "lang_dir": "data/lang_bpe_5000", + "log_interval": 50, + "lr_factor": 5.0, + "manifest_dir": "data/fbank", + "master_port": 12345, + "max_active_states": 10000, + "max_duration": 300, + "method": "ctc-decoding", + "min_active_states": 30, + "mini_libri": false, + "nhead": 4, + "num_buckets": 200, + "num_decoder_layers": 0, + "num_epochs": 100, + "num_workers": 24, + "on_the_fly_feats": false, + "output_beam": 8.0, + "reduction": "sum", + "reset_interval": 200, + "return_cuts": true, + "rir_cuts_path": "data/manifests/rir.scp", + "rir_prob": 0.5, + "sanity_check": true, + "search_beam": 20.0, + "seed": 42, + "shuffle": true, + "spec_aug_time_warp_factor": 80, + "start_epoch": 0, + "subsampling_factor": 4, + "tensorboard": true, + "use_double_scores": true, + "use_feat_batchnorm": true, + "valid_interval": 5000, + "valid_max_duration": 15, + "validation_decoding_method": "greedy", + "validation_output_beam": 5.0, + "validation_search_beam": 10.0, + "validation_skip_wer": false, + "warm_step": 30000, + "weight_decay": 1e-06, + "world_size": 3 +} +2025-08-27 10:02:24,048 INFO [train.py:1012] (0/3) About to create model +2025-08-27 10:02:24,072 INFO [train.py:1012] (2/3) About to create model +2025-08-27 10:02:24,123 INFO [train.py:1012] (1/3) About to create model +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-27 10:02:26,096 INFO [asr_datamodule.py:537] (0/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +2025-08-27 10:02:26,112 INFO [asr_datamodule.py:301] (0/3) Enable MUSAN +2025-08-27 10:02:26,112 INFO [asr_datamodule.py:302] (0/3) About to get Musan cuts +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-27 10:02:26,234 INFO [asr_datamodule.py:537] (2/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +2025-08-27 10:02:26,235 INFO [asr_datamodule.py:301] (2/3) Enable MUSAN +2025-08-27 10:02:26,236 INFO [asr_datamodule.py:302] (2/3) About to get Musan cuts +/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer was not TransformerEncoderLayer + warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") +2025-08-27 10:02:26,238 INFO [asr_datamodule.py:537] 
(1/3) About to get the shuffled train-clean-100, train-clean-360 and train-other-500 cuts +2025-08-27 10:02:26,239 INFO [asr_datamodule.py:301] (1/3) Enable MUSAN +2025-08-27 10:02:26,239 INFO [asr_datamodule.py:302] (1/3) About to get Musan cuts +2025-08-27 10:02:28,823 INFO [asr_datamodule.py:311] (0/3) Enable RIR (Room Impulse Response) augmentation +2025-08-27 10:02:28,823 INFO [asr_datamodule.py:312] (0/3) Loading RIR paths from data/manifests/rir.scp +2025-08-27 10:02:28,829 INFO [asr_datamodule.py:311] (1/3) Enable RIR (Room Impulse Response) augmentation +2025-08-27 10:02:28,829 INFO [asr_datamodule.py:312] (1/3) Loading RIR paths from data/manifests/rir.scp +2025-08-27 10:02:28,845 INFO [asr_datamodule.py:319] (0/3) Found 60536 RIR files +2025-08-27 10:02:28,851 INFO [asr_datamodule.py:319] (1/3) Found 60536 RIR files +2025-08-27 10:02:29,081 INFO [asr_datamodule.py:333] (0/3) Using cut concatenation with duration factor 1.0 and gap 1.0. +2025-08-27 10:02:29,081 INFO [asr_datamodule.py:333] (1/3) Using cut concatenation with duration factor 1.0 and gap 1.0. +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:348] (0/3) Enable SpecAugment +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:348] (1/3) Enable SpecAugment +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:349] (0/3) Time warp factor: 80 +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:349] (1/3) Time warp factor: 80 +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:359] (0/3) Num frame mask: 10 +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:359] (1/3) Num frame mask: 10 +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:379] (0/3) About to create train dataset +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:379] (1/3) About to create train dataset +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:389] (0/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:389] (1/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:393] (0/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:393] (1/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:404] (0/3) Using DynamicBucketingSampler. +2025-08-27 10:02:29,082 INFO [asr_datamodule.py:404] (1/3) Using DynamicBucketingSampler. +2025-08-27 10:02:29,203 INFO [asr_datamodule.py:311] (2/3) Enable RIR (Room Impulse Response) augmentation +2025-08-27 10:02:29,204 INFO [asr_datamodule.py:312] (2/3) Loading RIR paths from data/manifests/rir.scp +2025-08-27 10:02:29,228 INFO [asr_datamodule.py:319] (2/3) Found 60536 RIR files +2025-08-27 10:02:29,236 INFO [asr_datamodule.py:333] (2/3) Using cut concatenation with duration factor 1.0 and gap 1.0. 
+2025-08-27 10:02:29,237 INFO [asr_datamodule.py:348] (2/3) Enable SpecAugment +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:349] (2/3) Time warp factor: 80 +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:359] (2/3) Num frame mask: 10 +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:379] (2/3) About to create train dataset +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:389] (2/3) Train dataset augmentations: Cut transforms: ['CutConcatenate', 'CutMix', 'RandomRIRTransform']; Input transforms: ['SpecAugment'] +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:393] (2/3) Train dataset: 3 cut transforms, 1 input transforms +2025-08-27 10:02:29,237 INFO [asr_datamodule.py:404] (2/3) Using DynamicBucketingSampler. +2025-08-27 10:02:29,771 INFO [asr_datamodule.py:422] (0/3) About to create train dataloader +2025-08-27 10:02:29,772 INFO [asr_datamodule.py:422] (1/3) About to create train dataloader +2025-08-27 10:02:29,774 INFO [asr_datamodule.py:554] (0/3) About to get dev-clean cuts +2025-08-27 10:02:29,775 INFO [asr_datamodule.py:554] (1/3) About to get dev-clean cuts +2025-08-27 10:02:29,788 INFO [asr_datamodule.py:455] (0/3) Validation max_duration: 15 seconds +2025-08-27 10:02:29,788 INFO [asr_datamodule.py:455] (1/3) Validation max_duration: 15 seconds +2025-08-27 10:02:29,788 INFO [asr_datamodule.py:457] (0/3) About to create dev dataset +2025-08-27 10:02:29,788 INFO [asr_datamodule.py:457] (1/3) About to create dev dataset +2025-08-27 10:02:29,920 INFO [asr_datamodule.py:422] (2/3) About to create train dataloader +2025-08-27 10:02:29,923 INFO [asr_datamodule.py:554] (2/3) About to get dev-clean cuts +2025-08-27 10:02:29,924 INFO [asr_datamodule.py:455] (2/3) Validation max_duration: 15 seconds +2025-08-27 10:02:29,924 INFO [asr_datamodule.py:457] (2/3) About to create dev dataset +2025-08-27 10:02:29,958 INFO [asr_datamodule.py:474] (0/3) About to create dev dataloader +2025-08-27 10:02:29,958 INFO [train.py:1068] (0/3) Validation set size: 2703 utterances +2025-08-27 10:02:29,958 INFO [train.py:1129] (0/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. +2025-08-27 10:02:29,961 INFO [asr_datamodule.py:474] (1/3) About to create dev dataloader +2025-08-27 10:02:29,961 INFO [train.py:1068] (1/3) Validation set size: 2703 utterances +2025-08-27 10:02:29,961 INFO [train.py:1129] (1/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. +2025-08-27 10:02:30,094 INFO [asr_datamodule.py:474] (2/3) About to create dev dataloader +2025-08-27 10:02:30,094 INFO [train.py:1068] (2/3) Validation set size: 2703 utterances +2025-08-27 10:02:30,094 INFO [train.py:1129] (2/3) Sanity check -- see if any of the batches in epoch 0 would cause OOM. 
+W0827 10:02:58.158139 127413766444864 torch/multiprocessing/spawn.py:146] Terminating process 1291938 via signal SIGTERM +W0827 10:02:58.158930 127413766444864 torch/multiprocessing/spawn.py:146] Terminating process 1291939 via signal SIGTERM +Traceback (most recent call last): + File "./conformer_ctc/train.py", line 1415, in <module> + main() + File "./conformer_ctc/train.py", line 1408, in main + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes + while not context.join(): + File "/home/jenny/miniconda3/envs/jenny/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 170, in join + raise ProcessExitedException( +torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGKILL diff --git a/egs/librispeech/ASR/transducer/README.md b/egs/librispeech/ASR/transducer/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/__init__.py b/egs/librispeech/ASR/transducer/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/joiner.py b/egs/librispeech/ASR/transducer/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/README.md b/egs/librispeech/ASR/transducer_lstm/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/__init__.py b/egs/librispeech/ASR/transducer_lstm/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/encoder_interface.py b/egs/librispeech/ASR/transducer_lstm/encoder_interface.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_lstm/noam.py b/egs/librispeech/ASR/transducer_lstm/noam.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/README.md b/egs/librispeech/ASR/transducer_stateless/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/__init__.py b/egs/librispeech/ASR/transducer_stateless/__init__.py old mode 100644
new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless2/__init__.py b/egs/librispeech/ASR/transducer_stateless2/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless2/model.py b/egs/librispeech/ASR/transducer_stateless2/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/README.md b/egs/librispeech/ASR/transducer_stateless_multi_datasets/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/__init__.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/upload_to_huggingface.py b/egs/librispeech/ASR/upload_to_huggingface.py new file mode 100644 index 000000000..5f0b975fb --- /dev/null +++ b/egs/librispeech/ASR/upload_to_huggingface.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +Script to upload icefall conformer CTC model to Hugging Face Hub +""" + +import os +import torch +import logging +from pathlib import Path +from typing import Dict, Any +import json +import shutil + +# Hugging Face imports +try: + from huggingface_hub import HfApi, create_repo, upload_folder + from huggingface_hub.utils import RepositoryNotFoundError +except ImportError: + print("Please install huggingface_hub: pip install huggingface_hub") + exit(1) + +def create_model_card(model_info: Dict[str, Any]) -> str: + """Create a model card for the Hugging Face model""" + + model_card = f"""--- +language: en +license: apache-2.0 +tags: +- speech +- audio +- automatic-speech-recognition +- icefall +- conformer +- ctc +library_name: icefall +datasets: +- librispeech_asr +metrics: +- wer +--- + +# {model_info['model_name']} + +This is a Conformer CTC model trained with icefall on LibriSpeech dataset. 
+ +## Model Description + +- **Architecture**: Conformer with CTC loss +- **Training Framework**: icefall +- **Dataset**: LibriSpeech ASR +- **Language**: English +- **Sample Rate**: 16kHz + +## Model Details + +- **Model Size**: {model_info.get('num_params', 'Unknown')} parameters +- **Feature Dimension**: {model_info.get('feature_dim', 80)} +- **Attention Dimension**: {model_info.get('attention_dim', 256)} +- **Number of Heads**: {model_info.get('nhead', 4)} +- **Subsampling Factor**: {model_info.get('subsampling_factor', 4)} + +## Training Information + +- **Best Valid Loss**: {model_info.get('best_valid_loss', 'Unknown')} +- **Training Epochs**: {model_info.get('epoch', 'Unknown')} +- **Optimizer**: Adam +- **Framework**: icefall + k2 + lhotse + +## Usage + +```python +# Load model with icefall +from icefall.checkpoint import load_checkpoint +from conformer import Conformer +import torch + +# Model configuration +model = Conformer( + num_features=80, + nhead=4, + d_model=256, + num_classes=5000, # Adjust based on your vocab size + subsampling_factor=4, + num_decoder_layers=0, + vgg_frontend=False, + use_feat_batchnorm=True, +) + +# Load checkpoint +load_checkpoint("best-valid-loss.pt", model) +model.eval() +``` + +## Citation + +If you use this model, please cite: + +```bibtex +@misc{{icefall2021, + title={{Icefall: A speech recognition toolkit with PyTorch}}, + author={{The icefall development team}}, + howpublished={{\\url{{https://github.com/k2-fsa/icefall}}}}, + year={{2021}} +}} +``` + +## License + +This model is released under the Apache 2.0 License. +""" + return model_card + +def extract_model_info(checkpoint_path: Path) -> Dict[str, Any]: + """Extract model information from checkpoint""" + + try: + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + model_info = { + 'model_name': 'icefall-conformer-ctc-librispeech', + 'checkpoint_path': str(checkpoint_path) + } + + # Extract information from checkpoint + if 'epoch' in checkpoint: + model_info['epoch'] = checkpoint['epoch'] + + if 'best_valid_loss' in checkpoint: + model_info['best_valid_loss'] = checkpoint['best_valid_loss'] + + if 'model' in checkpoint: + # Count parameters + num_params = sum(p.numel() for p in checkpoint['model'].values()) + model_info['num_params'] = f"{num_params:,}" + + # Model architecture info (you might need to adjust these) + model_info.update({ + 'feature_dim': 80, + 'attention_dim': 256, + 'nhead': 4, + 'subsampling_factor': 4 + }) + + return model_info + + except Exception as e: + logging.error(f"Error extracting model info: {e}") + return {'model_name': 'icefall-conformer-ctc-librispeech'} + +def create_config_json(model_info: Dict[str, Any]) -> Dict[str, Any]: + """Create a config.json file for the model""" + + config = { + "architectures": ["Conformer"], + "model_type": "conformer_ctc", + "framework": "icefall", + "feature_dim": model_info.get('feature_dim', 80), + "attention_dim": model_info.get('attention_dim', 256), + "nhead": model_info.get('nhead', 4), + "subsampling_factor": model_info.get('subsampling_factor', 4), + "num_decoder_layers": 0, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "sample_rate": 16000, + "language": "en" + } + + return config + +def upload_to_huggingface( + checkpoint_path: Path, + repo_name: str, + token: str = None, + private: bool = False +): + """Upload icefall model to Hugging Face Hub""" + + # Create temporary directory for upload + temp_dir = Path("./hf_upload_temp") + temp_dir.mkdir(exist_ok=True) + + try: + # Extract model 
information + print("Extracting model information...") + model_info = extract_model_info(checkpoint_path) + + # Copy model file + print("Copying model file...") + shutil.copy2(checkpoint_path, temp_dir / "best-valid-loss.pt") + + # Create model card + print("Creating model card...") + model_card = create_model_card(model_info) + with open(temp_dir / "README.md", "w") as f: + f.write(model_card) + + # Create config.json + print("Creating config.json...") + config = create_config_json(model_info) + with open(temp_dir / "config.json", "w") as f: + json.dump(config, f, indent=2) + + # Create additional files + print("Creating additional files...") + + # Create inference example + inference_example = '''#!/usr/bin/env python3 +""" +Example inference script for icefall Conformer CTC model +""" + +import torch +from pathlib import Path + +def load_model(model_path: str): + """Load the icefall Conformer model""" + + # You'll need to have icefall installed and import the Conformer class + # from conformer import Conformer + # from icefall.checkpoint import load_checkpoint + + # model = Conformer( + # num_features=80, + # nhead=4, + # d_model=256, + # num_classes=5000, # Adjust based on vocab + # subsampling_factor=4, + # num_decoder_layers=0, + # vgg_frontend=False, + # use_feat_batchnorm=True, + # ) + + # load_checkpoint(model_path, model) + # model.eval() + # return model + + pass + +if __name__ == "__main__": + model = load_model("best-valid-loss.pt") + print("Model loaded successfully!") +''' + + with open(temp_dir / "inference_example.py", "w") as f: + f.write(inference_example) + + # Create requirements.txt + requirements = """torch>=1.9.0 +torchaudio>=0.9.0 +k2 +lhotse +icefall +""" + with open(temp_dir / "requirements.txt", "w") as f: + f.write(requirements) + + # Initialize Hugging Face API + api = HfApi(token=token) + + # Create repository + print(f"Creating repository: {repo_name}") + try: + create_repo( + repo_id=repo_name, + token=token, + private=private, + repo_type="model" + ) + print(f"✅ Repository {repo_name} created successfully!") + except Exception as e: + if "already exists" in str(e).lower(): + print(f"Repository {repo_name} already exists, continuing...") + else: + raise e + + # Upload files + print("Uploading files to Hugging Face Hub...") + upload_folder( + folder_path=temp_dir, + repo_id=repo_name, + token=token, + commit_message="Upload icefall Conformer CTC model" + ) + + print(f"✅ Model uploaded successfully to: https://huggingface.co/{repo_name}") + + except Exception as e: + print(f"❌ Error uploading model: {e}") + raise e + + finally: + # Clean up + print("Cleaning up temporary files...") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + +def main(): + """Main function""" + + # Configuration + checkpoint_path = Path("/home/hdd2/jenny/ASRToolkit/icefall/egs/librispeech/ASR/conformer_ctc/exp-cleanASR/models/best-valid-loss.pt") + + # Get user input + repo_name = input("Enter repository name (e.g., username/model-name): ").strip() + if not repo_name: + print("Repository name is required!") + return + + token = input("Enter your Hugging Face token (or press Enter to use saved token): ").strip() + if not token: + token = None # Will use saved token from huggingface-cli login + + private = input("Make repository private? 
(y/N): ").strip().lower() == 'y' + + # Check if checkpoint exists + if not checkpoint_path.exists(): + print(f"❌ Checkpoint not found: {checkpoint_path}") + return + + print(f"📁 Checkpoint path: {checkpoint_path}") + print(f"🔗 Repository: {repo_name}") + print(f"🔒 Private: {private}") + + confirm = input("\\nProceed with upload? (y/N): ").strip().lower() + if confirm != 'y': + print("Upload cancelled.") + return + + # Upload model + upload_to_huggingface( + checkpoint_path=checkpoint_path, + repo_name=repo_name, + token=token, + private=private + ) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/.gitignore b/egs/librispeech/ASR/zipformer/.gitignore old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/decode_stream.py b/egs/librispeech/ASR/zipformer/decode_stream.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/label_smoothing.py b/egs/librispeech/ASR/zipformer/label_smoothing.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/streaming_beam_search.py b/egs/librispeech/ASR/zipformer/streaming_beam_search.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_ctc/__init__.py b/egs/librispeech/ASR/zipformer_ctc/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_ctc/decoder.py b/egs/librispeech/ASR/zipformer_ctc/decoder.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_mmi/__init__.py b/egs/librispeech/ASR/zipformer_mmi/__init__.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/zipformer_mmi/model.py 
b/egs/librispeech/ASR/zipformer_mmi/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/asr_datamodule.py b/egs/librispeech/SSL/hubert/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/dataset.py b/egs/librispeech/SSL/hubert/dataset.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/decode.py b/egs/librispeech/SSL/hubert/decode.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/decode_ce.py b/egs/librispeech/SSL/hubert/decode_ce.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/hubert.py b/egs/librispeech/SSL/hubert/hubert.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/hubert_ce.py b/egs/librispeech/SSL/hubert/hubert_ce.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/ssl_datamodule.py b/egs/librispeech/SSL/hubert/ssl_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/utils.py b/egs/librispeech/SSL/hubert/utils.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/hubert/wav2vec2_module.py b/egs/librispeech/SSL/hubert/wav2vec2_module.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py b/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/prepare_char.py b/egs/librispeech/SSL/local/prepare_char.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/prepare_lang.py b/egs/librispeech/SSL/local/prepare_lang.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/process_librispeech4finetune.py b/egs/librispeech/SSL/local/process_librispeech4finetune.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/process_librispeech4pretrain.py b/egs/librispeech/SSL/local/process_librispeech4pretrain.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/local/process_raw_cuts.py b/egs/librispeech/SSL/local/process_raw_cuts.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/decode.py b/egs/librispeech/SSL/zipformer/decode.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/hubert_ce.py b/egs/librispeech/SSL/zipformer/hubert_ce.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py old mode 
100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/utils.py b/egs/librispeech/SSL/zipformer/utils.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/wav2vec2_module.py b/egs/librispeech/SSL/zipformer/wav2vec2_module.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/README.md b/egs/librispeech/WSASR/README.md old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/conformer_ctc2/conformer.py b/egs/librispeech/WSASR/conformer_ctc2/conformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/conformer_ctc2/transformer.py b/egs/librispeech/WSASR/conformer_ctc2/transformer.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/del.png b/egs/librispeech/WSASR/figures/del.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/ins.png b/egs/librispeech/WSASR/figures/ins.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/otc_emission.drawio.png b/egs/librispeech/WSASR/figures/otc_emission.drawio.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/otc_g.png b/egs/librispeech/WSASR/figures/otc_g.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png b/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/figures/sub.png b/egs/librispeech/WSASR/figures/sub.png old mode 100644 new mode 100755 diff --git a/egs/librispeech/WSASR/local/filter_cuts.py b/egs/librispeech/WSASR/local/filter_cuts.py old mode 100644 new mode 100755 diff --git a/egs/libritts/ASR/README.md b/egs/libritts/ASR/README.md deleted file mode 100644 index 138f4ae80..000000000 --- a/egs/libritts/ASR/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# Introduction - -LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. -The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. -The main differences from the LibriSpeech corpus are listed below: -1. The audio files are at 24kHz sampling rate. -2. The speech is split at sentence breaks. -3. Both original and normalized texts are included. -4. Contextual information (e.g., neighbouring sentences) can be extracted. -5. Utterances with significant background noise are excluded. -For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. 
- - -This recipe includes some different ASR models trained with [LibriTTS](https://openslr.org/60/). - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -| | Encoder | Decoder | -|---------------------------------------|---------------------|--------------------| -| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | - -The decoder is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/libritts/ASR/RESULTS.md b/egs/libritts/ASR/RESULTS.md deleted file mode 100644 index 574f81eb6..000000000 --- a/egs/libritts/ASR/RESULTS.md +++ /dev/null @@ -1,58 +0,0 @@ -# Results - -## zipformer (zipformer + pruned stateless transducer) - -See for more details. - -[zipformer](./zipformer) - -### Non-streaming - -#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M - -You can find a pretrained model, training logs, decoding logs, and decoding results at: - - -You can use to deploy it. - -| decoding method | test-clean | test-other | comment | -|----------------------|------------|------------|--------------------| -| greedy_search | 2.83 | 5.91 | --epoch 30 --avg 5 | -| modified_beam_search | 2.80 | 5.87 | --epoch 30 --avg 5 | -| fast_beam_search | 2.87 | 5.86 | --epoch 30 --avg 5 | -| greedy_search | 2.76 | 5.68 | --epoch 40 --avg 16| -| modified_beam_search | 2.74 | 5.66 | --epoch 40 --avg 16| -| fast_beam_search | 2.75 | 5.67 | --epoch 40 --avg 16| -| greedy_search | 2.74 | 5.67 | --epoch 50 --avg 30| -| modified_beam_search | 2.73 | 5.58 | --epoch 50 --avg 30| -| fast_beam_search | 2.78 | 5.61 | --epoch 50 --avg 30| - - -The training command is: -```bash -export CUDA_VISIBLE_DEVICES="0,1" -./zipformer/train.py \ - --world-size 2 \ - --num-epochs 50 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 0 \ - --full-libri 1 \ - --max-duration 3600 -``` -This was used on 2 Nvidia A800 GPUs, you'll need to adjust the `CUDA_VISIBLE_DEVICES`, `--world-size` and `--max-duration` according to your hardware. - -The decoding command is: -```bash -export CUDA_VISIBLE_DEVICES="0" -for m in greedy_search modified_beam_search fast_beam_search; do - ./zipformer/decode.py \ - --epoch 50 \ - --avg 30 \ - --use-averaged-model 1 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method $m -done -``` diff --git a/egs/libritts/ASR/local/compile_hlg.py b/egs/libritts/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/libritts/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compile_lg.py b/egs/libritts/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/libritts/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compute_fbank_libritts.py b/egs/libritts/ASR/local/compute_fbank_libritts.py deleted file mode 100755 index b6e2a4c43..000000000 --- a/egs/libritts/ASR/local/compute_fbank_libritts.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao,) -# 2024 The Chinese Univ. 
of HK (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LibriTTS dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=True, - help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", - ) - parser.add_argument( - "--sampling-rate", - type=int, - default=24000, - help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", - ) - - return parser.parse_args() - - -def compute_fbank_libritts( - dataset: Optional[str] = None, - sampling_rate: int = 24000, - perturb_speed: Optional[bool] = True, -): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(32, os.cpu_count()) - - num_mel_bins = 80 - - if dataset is None: - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) - else: - dataset_parts = dataset.split(" ", -1) - - prefix = "libritts" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if sampling_rate != 24000: - logging.info(f"Resampling audio to {sampling_rate}Hz") - cut_set = cut_set.resample(sampling_rate) - if "train" in partition: - if perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - - compute_fbank_libritts( - dataset=args.dataset, - sampling_rate=args.sampling_rate, - perturb_speed=args.perturb_speed, - ) diff --git a/egs/libritts/ASR/local/compute_fbank_musan.py b/egs/libritts/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/libritts/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py b/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/display_manifest_statistics.py b/egs/libritts/ASR/local/display_manifest_statistics.py deleted file mode 100755 index ddd022c96..000000000 --- a/egs/libritts/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,341 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. 
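-
-Usage example:
-
-    python3 ./local/display_manifest_statistics.py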
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/fbank/libritts_cuts_train-clean-100.jsonl.gz", - "./data/fbank/libritts_cuts_train-clean-360.jsonl.gz", - "./data/fbank/libritts_cuts_train-other-500.jsonl.gz", - "./data/fbank/libritts_cuts_dev-clean.jsonl.gz", - "./data/fbank/libritts_cuts_dev-other.jsonl.gz", - "./data/fbank/libritts_cuts_test-clean.jsonl.gz", - "./data/fbank/libritts_cuts_test-other.jsonl.gz", - ] - for path in paths: - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -./data/fbank/libritts_cuts_train-clean-100.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 33236 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 53:47:18 _ -________________________________________ -_ mean _ 5.8 _ -________________________________________ -_ std _ 4.6 _ -________________________________________ -_ min _ 0.2 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.5 _ -________________________________________ -_ 75% _ 7.9 _ -________________________________________ -_ 99% _ 21.4 _ -________________________________________ -_ 99.5% _ 23.7 _ -________________________________________ -_ 99.9% _ 27.8 _ -________________________________________ -_ max _ 33.2 _ -________________________________________ -_ Recordings available: _ 33236 _ -________________________________________ -_ Features available: _ 33236 _ -________________________________________ -_ Supervisions available: _ 33236 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 53:47:18 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/fbank/libritts_cuts_train-clean-360.jsonl.gz statistics: -_________________________________________ -_ Cuts count: _ 116500 _ -_________________________________________ -_ Total duration (hh:mm:ss) _ 191:17:42 _ -_________________________________________ -_ mean _ 5.9 _ -_________________________________________ -_ std _ 4.6 _ -_________________________________________ -_ min _ 0.1 _ -_________________________________________ -_ 25% _ 2.4 _ -_________________________________________ -_ 50% _ 4.6 _ -_________________________________________ -_ 75% _ 8.1 _ -_________________________________________ -_ 99% _ 21.3 _ -_________________________________________ -_ 99.5% _ 23.4 _ -_________________________________________ -_ 99.9% _ 27.4 _ -_________________________________________ -_ max _ 40.4 _ -_________________________________________ -_ Recordings available: _ 116500 _ -_________________________________________ -_ Features available: _ 116500 _ -_________________________________________ -_ Supervisions available: _ 116500 _ -_________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -___________________________________________________________________ -_ Total speech duration _ 191:17:42 _ 100.00% of recording _ -___________________________________________________________________ -_ Total speaking time duration _ 
191:17:42 _ 100.00% of recording _ -___________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -___________________________________________________________________ - -./data/fbank/libritts_cuts_train-other-500.jsonl.gz statistics: -_________________________________________ -_ Cuts count: _ 205043 _ -_________________________________________ -_ Total duration (hh:mm:ss) _ 310:04:36 _ -_________________________________________ -_ mean _ 5.4 _ -_________________________________________ -_ std _ 4.4 _ -_________________________________________ -_ min _ 0.1 _ -_________________________________________ -_ 25% _ 2.3 _ -_________________________________________ -_ 50% _ 4.2 _ -_________________________________________ -_ 75% _ 7.3 _ -_________________________________________ -_ 99% _ 20.6 _ -_________________________________________ -_ 99.5% _ 22.8 _ -_________________________________________ -_ 99.9% _ 27.4 _ -_________________________________________ -_ max _ 43.9 _ -_________________________________________ -_ Recordings available: _ 205043 _ -_________________________________________ -_ Features available: _ 205043 _ -_________________________________________ -_ Supervisions available: _ 205043 _ -_________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -___________________________________________________________________ -_ Total speech duration _ 310:04:36 _ 100.00% of recording _ -___________________________________________________________________ -_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ -___________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -___________________________________________________________________ - -./data/fbank/libritts_cuts_dev-clean.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 5736 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 08:58:13 _ -________________________________________ -_ mean _ 5.6 _ -________________________________________ -_ std _ 4.3 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.4 _ -________________________________________ -_ 75% _ 7.8 _ -________________________________________ -_ 99% _ 19.9 _ -________________________________________ -_ 99.5% _ 21.9 _ -________________________________________ -_ 99.9% _ 26.3 _ -________________________________________ -_ max _ 30.1 _ -________________________________________ -_ Recordings available: _ 5736 _ -________________________________________ -_ Features available: _ 5736 _ -________________________________________ -_ Supervisions available: _ 5736 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 08:58:13 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/fbank/libritts_cuts_dev-other.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 4613 _ 
-________________________________________ -_ Total duration (hh:mm:ss) _ 06:25:52 _ -________________________________________ -_ mean _ 5.0 _ -________________________________________ -_ std _ 4.1 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.2 _ -________________________________________ -_ 50% _ 3.8 _ -________________________________________ -_ 75% _ 6.5 _ -________________________________________ -_ 99% _ 19.7 _ -________________________________________ -_ 99.5% _ 24.5 _ -________________________________________ -_ 99.9% _ 31.0 _ -________________________________________ -_ max _ 32.6 _ -________________________________________ -_ Recordings available: _ 4613 _ -________________________________________ -_ Features available: _ 4613 _ -________________________________________ -_ Supervisions available: _ 4613 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 06:25:52 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/fbank/libritts_cuts_test-clean.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 4837 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 08:34:09 _ -________________________________________ -_ mean _ 6.4 _ -________________________________________ -_ std _ 5.1 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.8 _ -________________________________________ -_ 75% _ 8.9 _ -________________________________________ -_ 99% _ 22.6 _ -________________________________________ -_ 99.5% _ 24.4 _ -________________________________________ -_ 99.9% _ 29.6 _ -________________________________________ -_ max _ 36.7 _ -________________________________________ -_ Recordings available: _ 4837 _ -________________________________________ -_ Features available: _ 4837 _ -________________________________________ -_ Supervisions available: _ 4837 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 08:34:09 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/fbank/libritts_cuts_test-other.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 5120 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 06:41:31 _ -________________________________________ -_ mean _ 4.7 _ -________________________________________ -_ std _ 3.8 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 1.8 _ -________________________________________ -_ 50% _ 3.6 _ 
-________________________________________ -_ 75% _ 6.5 _ -________________________________________ -_ 99% _ 17.8 _ -________________________________________ -_ 99.5% _ 20.4 _ -________________________________________ -_ 99.9% _ 23.8 _ -________________________________________ -_ max _ 27.3 _ -________________________________________ -_ Recordings available: _ 5120 _ -________________________________________ -_ Features available: _ 5120 _ -________________________________________ -_ Supervisions available: _ 5120 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 06:41:31 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ -""" diff --git a/egs/libritts/ASR/local/download_lm.py b/egs/libritts/ASR/local/download_lm.py deleted file mode 120000 index c9668bd2d..000000000 --- a/egs/libritts/ASR/local/download_lm.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/download_lm.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/norm_text.py b/egs/libritts/ASR/local/norm_text.py deleted file mode 120000 index dea3c051f..000000000 --- a/egs/libritts/ASR/local/norm_text.py +++ /dev/null @@ -1 +0,0 @@ -../../../libriheavy/ASR/local/norm_text.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang.py b/egs/libritts/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/libritts/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_bpe.py b/egs/libritts/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/libritts/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_fst.py b/egs/libritts/ASR/local/prepare_lang_fst.py deleted file mode 120000 index c5787c534..000000000 --- a/egs/libritts/ASR/local/prepare_lang_fst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lm_training_data.py b/egs/libritts/ASR/local/prepare_lm_training_data.py deleted file mode 120000 index abc00d421..000000000 --- a/egs/libritts/ASR/local/prepare_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/train_bpe_model.py b/egs/libritts/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/libritts/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/validate_bpe_lexicon.py b/egs/libritts/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/libritts/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff 
--git a/egs/libritts/ASR/local/validate_manifest.py b/egs/libritts/ASR/local/validate_manifest.py deleted file mode 100755 index abd4da88a..000000000 --- a/egs/libritts/ASR/local/validate_manifest.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao,) -# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest -from lhotse.dataset.speech_recognition import validate_for_asr - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest(manifest) - assert isinstance(cut_set, CutSet) - - validate_for_asr(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh deleted file mode 100755 index 9d9ce8f87..000000000 --- a/egs/libritts/ASR/prepare.sh +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=100 -sampling_rate=16000 -nj=32 -perturb_speed=true -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download LM" # we directly use the librispeech lm here - mkdir -p $dl_dir/lm - if [ ! -e $dl_dir/lm/.done ]; then - ./local/download_lm.py --out-dir=$dl_dir/lm - touch $dl_dir/lm/.done - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/LibriTTS, - # you can create a symlink - # - # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS - # - if [ ! 
-d $dl_dir/LibriTTS ]; then - lhotse download libritts $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LibriTTS manifest" - # We assume that you have downloaded the LibriTTS corpus - # to $dl_dir/LibriTTS - mkdir -p data/manifests - if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests - touch data/manifests/.libritts.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute Fbank for LibriTTS" - mkdir -p data/fbank - if [ ! -e data/fbank/.libritts.done ]; then - ./local/compute_fbank_libritts.py \ - --sampling-rate $sampling_rate \ - --perturb-speed $perturb_speed - touch data/fbank/.libritts.done - fi - - # Here we shuffle and combine the train-clean-100, train-clean-360 and - # train-other-500 together to form the training set. - if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz - fi - - if [ ! -e data/fbank/.libritts-validated.done ]; then - log "Validating data/fbank for LibriTTS" - ./local/validate_manifest.py \ - data/fbank/libritts_cuts_train-all-shuf.jsonl.gz - touch data/fbank/.libritts-validated.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Train BPE model for normalized text" - - if [ ! -f data/text ]; then - gunzip -c data/manifests/libritts_supervisions_train-clean-100.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py > data/text - - gunzip -c data/manifests/libritts_supervisions_train-clean-360.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> data/text - - gunzip -c data/manifests/libritts_supervisions_train-other-500.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> data/text - fi - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - cp data/text $lang_dir/text - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - done -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare phone based lang" - lang_dir=data/lang_phone - mkdir -p $lang_dir - - if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then - log "No lexicon file in $dl_dir/lm, please run :" - log "prepare.sh --stage -1 --stop-stage -1" - exit -1 - fi - - if [ ! 
-f $lang_dir/lexicon.txt ]; then - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/librispeech-lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi -fi diff --git a/egs/libritts/ASR/prepare_lm.sh b/egs/libritts/ASR/prepare_lm.sh deleted file mode 100755 index 1c690983b..000000000 --- a/egs/libritts/ASR/prepare_lm.sh +++ /dev/null @@ -1,264 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -# This script generate Ngram LM / NNLM and related files that needed by decoding. - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/lm -# This directory contains the following files downloaded from -# http://www.openslr.org/resources/11 -# -# - 3-gram.pruned.1e-7.arpa.gz -# - 3-gram.pruned.1e-7.arpa -# - 4-gram.arpa.gz -# - 4-gram.arpa -# - librispeech-vocab.txt -# - librispeech-lexicon.txt -# - librispeech-lm-norm.txt.gz -# - -. prepare.sh --stage -1 --stop-stage 6 || exit 1 - -log "Running prepare_lm.sh" - -stage=0 -stop_stage=100 - -. shared/parse_options.sh || exit 1 - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Prepare BPE based lexicon." - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare word level G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt - fi - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! 
-f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_dir \ - --ngram-G ./data/lm/G_3_gram.fst.txt - fi - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compile LG" - ./local/compile_lg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare token level ngram G" - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - for ngram in 2 3 4 5; do - if [ ! -f $lang_dir/${ngram}gram.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/${ngram}gram.arpa - fi - - if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt - fi - done - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Generate NNLM training data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ - --lm-archive $out_dir/lm_data.pt - done -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate NNLM validation data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/valid.txt ]; then - gunzip -c data/manifests/libritts_supervisions_dev-clean.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py > $out_dir/valid.txt - - gunzip -c data/manifests/libritts_supervisions_dev-other.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> $out_dir/valid.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Generate NNLM test data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! 
-f $out_dir/test.txt ]; then - gunzip -c data/manifests/libritts_supervisions_test-clean.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py > $out_dir/test.txt - - gunzip -c data/manifests/libritts_supervisions_test-other.jsonl.gz \ - | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> $out_dir/test.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/test.txt \ - --lm-archive $out_dir/lm_data-test.pt - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Sort NNLM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. - - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt - done -fi diff --git a/egs/libritts/ASR/shared b/egs/libritts/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/libritts/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/.gitignore b/egs/libritts/ASR/zipformer/.gitignore deleted file mode 100644 index e47ac1582..000000000 --- a/egs/libritts/ASR/zipformer/.gitignore +++ /dev/null @@ -1 +0,0 @@ -swoosh.pdf diff --git a/egs/libritts/ASR/zipformer/asr_datamodule.py b/egs/libritts/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index dab834303..000000000 --- a/egs/libritts/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
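The `LibriTTSAsrDataModule` defined below is normally driven by this recipe's `train.py` and `decode.py`. As a minimal, hypothetical usage sketch (flag values are illustrative, not the recipe's defaults), it can be wired up roughly like this:

```python
# Hypothetical usage sketch of the LibriTTSAsrDataModule defined below;
# the recipe's train.py / decode.py are the authoritative consumers.
import argparse

from asr_datamodule import LibriTTSAsrDataModule

parser = argparse.ArgumentParser()
LibriTTSAsrDataModule.add_arguments(parser)  # adds --max-duration, --num-buckets, ...
args = parser.parse_args(["--max-duration", "600"])

datamodule = LibriTTSAsrDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_all_shuf_cuts())
valid_dl = datamodule.valid_dataloaders(datamodule.dev_clean_cuts())
test_dl = datamodule.test_dataloaders(datamodule.test_clean_cuts())
```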
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriTTSAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. libritts test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="""When enabled, use the entire LibriTTS training set. - Otherwise, use the 100h subset.""", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
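-            # Default to 10 frame masks; if the installed Lhotse still uses the
-            # old default of 1, fall back to 2 (see the signature check below).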
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
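-        # _SeedWorkers then re-seeds every dataloader worker with seed + worker_id,
-        # so augmentation randomness differs across workers while remaining
-        # reproducible for a fixed main-process seed.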
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" - ) diff 
--git a/egs/libritts/ASR/zipformer/attention_decoder.py b/egs/libritts/ASR/zipformer/attention_decoder.py deleted file mode 120000 index 384e1b95e..000000000 --- a/egs/libritts/ASR/zipformer/attention_decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/beam_search.py b/egs/libritts/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/libritts/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py deleted file mode 100755 index d77aa5962..000000000 --- a/egs/libritts/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1,992 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Liyong Guo, -# Quandong Wang, -# Zengwei Yao) -# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) ctc-greedy-search -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --decoding-method ctc-greedy-search - -(2) ctc-decoding -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --decoding-method ctc-decoding - -(3) 1best -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method 1best - -(4) nbest -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method nbest - -(5) nbest-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method nbest-rescoring - -(6) whole-lattice-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method whole-lattice-rescoring - -(7) attention-decoder-rescoring-no-ngram -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --use-attention-decoder 1 \ - --max-duration 100 \ - --decoding-method attention-decoder-rescoring-no-ngram - -(8) attention-decoder-rescoring-with-ngram -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --use-attention-decoder 1 \ - --max-duration 100 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method 
attention-decoder-rescoring-with-ngram -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriTTSAsrDataModule -from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params, normalize_text - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import ( - ctc_greedy_search, - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder_no_ngram, - rescore_with_attention_decoder_with_ngram, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="ctc-decoding", - help="""Decoding method. - Supported values are: - - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (3) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (4) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. 
- - (5) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (6) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - you have trained an RNN LM using ./rnn_lm/train.py - - (7) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding - lattice, rescore them with the attention decoder. - - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM - rescored lattice, rescore them with the attention decoder. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--hlg-scale", - type=float, - default=0.6, - help="""The scale to be applied to `hlg.scores`. - """, - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The n-gram LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - parser.add_argument( - "--skip-scoring", - type=str2bool, - default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""", - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. - - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. - - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. - - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. 
- H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - if params.decoding_method == "ctc-greedy-search": - hyps = ctc_greedy_search(ctc_output, encoder_out_lens) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(hyps) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-greedy-search" - return {key: hyps} - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.decoding_method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} # note: returns words - - if params.decoding_method == "attention-decoder-rescoring-no-ngram": - best_path_dict = rescore_with_attention_decoder_no_ngram( - lattice=lattice, - num_paths=params.num_paths, - attention_decoder=model.attention_decoder, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - nbest_scale=params.nbest_scale, - ) - ans = dict() - for a_scale_str, best_path in best_path_dict.items(): - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - ans[a_scale_str] = hyps - return ans - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa - return {key: hyps} - - if params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no-rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} # note: returns BPE tokens - - assert params.decoding_method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder-rescoring-with-ngram", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.decoding_method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.decoding_method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "attention-decoder-rescoring-with-ngram": - # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. 
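The hard-coded `lm_scale_list` above sweeps the n-gram LM weight from 0.1 to 2.0 in steps of 0.1; each scale yields one entry in the returned dict, keyed like `lm_scale_0.7`. A minimal, hypothetical helper (not part of the recipe) for generating or narrowing such a grid:

    def make_lm_scale_grid(lo: float = 0.1, hi: float = 2.0, step: float = 0.1):
        # make_lm_scale_grid() reproduces the 0.1 ... 2.0 list used above;
        # a narrower range (e.g. lo=0.5, hi=1.0) cuts the number of rescoring passes.
        n = int(round((hi - lo) / step))
        return [round(lo + i * step, 2) for i in range(n + 1)]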
- rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - best_path_dict = rescore_with_attention_decoder_with_ngram( - lattice=rescored_lattice, - num_paths=params.num_paths, - attention_decoder=model.attention_decoder, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - nbest_scale=params.nbest_scale, - ) - else: - assert False, f"Unsupported decoding method: {params.decoding_method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - word_table: - It is the word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - """ - Save text produced by ASR. 
- """ - for key, results in results_dict.items(): - - recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - - results = sorted(results) - store_transcripts(filename=recogs_filename, texts=results) - - logging.info(f"The transcripts are stored in {recogs_filename}") - - -def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - if params.decoding_method in ( - "attention-decoder-rescoring-with-ngram", - "whole-lattice-rescoring", - ): - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - - test_set_wers = dict() - for key, results in results_dict.items(): - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w", encoding="utf8") as fd: - wer = write_error_stats( - fd, f"{test_set_name}_{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - logging.info(f"Wrote detailed error stats to {errs_filename}") - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - - wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - - with open(wer_filename, "w", encoding="utf8") as fd: - print("settings\tWER", file=fd) - for key, val in test_set_wers: - print(f"{key}\t{val}", file=fd) - - s = f"\nFor {test_set_name}, WER of different settings are:\n" - note = f"\tbest for {test_set_name}" - for key, val in test_set_wers: - s += f"{key}\t{val}{note}\n" - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriTTSAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - # enable AudioCache - set_caching_enabled(True) # lhotse - - assert params.decoding_method in ( - "ctc-greedy-search", - "ctc-decoding", - "1best", - "nbest", - "nbest-rescoring", - "whole-lattice-rescoring", - "nbest-oracle", - "attention-decoder-rescoring-no-ngram", - "attention-decoder-rescoring-with-ngram", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}_avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"_chunk-{params.chunk_size}" - params.suffix += f"_left-context-{params.left_context_frames}" - - if params.use_averaged_model: - params.suffix += "_use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - params.vocab_size = num_classes - # and are defined in local/train_bpe_model.py - params.blank_id = 0 - params.eos_id = 1 - params.sos_id = 1 - - if params.decoding_method in [ - "ctc-greedy-search", - "ctc-decoding", - "attention-decoder-rescoring-no-ngram", - ]: - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - HLG.scores *= params.hlg_scale - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.decoding_method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder-rescoring-with-ngram", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.decoding_method in [ - "whole-lattice-rescoring", - "attention-decoder-rescoring-with-ngram", - ]: - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. 
- G.lm_scores = G.scores.clone() - else: - G = None - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
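As a rough aside on the checkpoint-averaging branches above: conceptually, `--avg` boils down to an element-wise mean over several saved state dicts. A standalone sketch, assuming (not verified here) that each checkpoint stores its weights under a "model" key; the real `average_checkpoints` / `average_checkpoints_with_averaged_model` helpers imported above also handle the running-average bookkeeping, so this is for intuition only:

    import torch

    def average_state_dicts(filenames, device="cpu"):
        # Element-wise mean of the "model" state_dicts from several checkpoints.
        avg = None
        for f in filenames:
            sd = torch.load(f, map_location=device)["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in sd.items()}
            else:
                for k in avg:
                    avg[k] += sd[k].float()
        return {k: v / len(filenames) for k, v in avg.items()}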
- args.return_cuts = True - libritts = LibriTTSAsrDataModule(args) - - test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) - test_other_cuts = libritts.test_other_cuts().map(normalize_text) - - test_clean_dl = libritts.test_dataloaders(test_clean_cuts) - test_other_dl = libritts.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - ) - - save_asr_output( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if not params.skip_scoring: - save_wer_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py deleted file mode 100755 index 759d9d50a..000000000 --- a/egs/libritts/ASR/zipformer/decode.py +++ /dev/null @@ -1,1086 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriTTSAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params, normalize_text - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--skip-scoring", - type=str2bool, - default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) - prefix = f"{params.decoding_method}" - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - prefix += f"_beam-{params.beam}" - prefix += f"_max-contexts-{params.max_contexts}" - prefix += f"_max-states-{params.max_states}" - if "nbest" in params.decoding_method: - prefix += f"_num-paths-{params.num_paths}" - prefix += f"_nbest-scale-{params.nbest_scale}" - if "LG" in params.decoding_method: - prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - - return {prefix: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix += f"_beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"_context-score-{params.context_score}" - return {prefix: hyps} - else: - prefix += f"_beam-size-{params.beam_size}" - return {prefix: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. 
- Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - """ - Save text produced by ASR. - """ - for key, results in results_dict.items(): - - recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - - results = sorted(results) - store_transcripts(filename=recogs_filename, texts=results) - - logging.info(f"The transcripts are stored in {recogs_filename}") - - -def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], -): - """ - Save WER and per-utterance word alignments. - """ - test_set_wers = dict() - for key, results in results_dict.items(): - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
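write_error_stats below produces the detailed per-word report; for orientation only, the headline WER over the (cut_id, ref_words, hyp_words) tuples collected in decode_dataset reduces to an edit-distance computation. A hypothetical, simplified helper (not the icefall implementation):

    def simple_wer(results):
        # results: list of (cut_id, ref_words, hyp_words) tuples as built above.
        errors = 0
        ref_len = 0
        for _cut_id, ref, hyp in results:
            # Rolling-row Levenshtein distance between the two word sequences.
            dp = list(range(len(hyp) + 1))
            for i, r in enumerate(ref, 1):
                prev, dp[0] = dp[0], i
                for j, h in enumerate(hyp, 1):
                    cur = min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
                    prev, dp[j] = dp[j], cur
            errors += dp[len(hyp)]
            ref_len += len(ref)
        return errors / max(ref_len, 1)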
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w", encoding="utf8") as fd: - wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info(f"Wrote detailed error stats to {errs_filename}") - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - - wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - - with open(wer_filename, "w", encoding="utf8") as fd: - print("settings\tWER", file=fd) - for key, val in test_set_wers: - print(f"{key}\t{val}", file=fd) - - s = f"\nFor {test_set_name}, WER of different settings are:\n" - note = f"\tbest for {test_set_name}" - for key, val in test_set_wers: - s += f"{key}\t{val}{note}\n" - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriTTSAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - # enable AudioCache - set_caching_enabled(True) # lhotse - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}_avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
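# For orientation (using only the argparse defaults shown above): with
# --exp-dir zipformer/exp and --decoding-method greedy_search, outputs go to
# zipformer/exp/greedy_search/. params.suffix starts with "epoch-30_avg-15",
# the appends that follow add method-specific pieces, and (with
# --use-averaged-model true) it ends in "_use-averaged-model", producing files
# such as recogs-test-clean-epoch-30_avg-15..._use-averaged-model.txt.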
- params.suffix += f"_chunk-{params.chunk_size}" - params.suffix += f"_left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"_beam-{params.beam}" - params.suffix += f"_max-contexts-{params.max_contexts}" - params.suffix += f"_max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"_nbest-scale-{params.nbest_scale}" - params.suffix += f"_num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"_context-{params.context_size}" - params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "_use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append((sp.encode(line.strip()), 0.0)) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - libritts = LibriTTSAsrDataModule(args) - - test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) - test_other_cuts = libritts.test_other_cuts().map(normalize_text) - - test_clean_dl = libritts.test_dataloaders(test_clean_cuts) - test_other_dl = libritts.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_asr_output( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if not params.skip_scoring: - save_wer_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/ASR/zipformer/decode_stream.py b/egs/libritts/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/libritts/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/decoder.py b/egs/libritts/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/libritts/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/encoder_interface.py b/egs/libritts/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/libritts/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-ctc.py deleted file mode 120000 index f9d756352..000000000 --- a/egs/libritts/ASR/zipformer/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py deleted file mode 120000 index 652346001..000000000 --- a/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming.py b/egs/libritts/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/libritts/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx.py b/egs/libritts/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/libritts/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export.py b/egs/libritts/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/libritts/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No 
newline at end of file diff --git a/egs/libritts/ASR/zipformer/generate_averaged_model.py b/egs/libritts/ASR/zipformer/generate_averaged_model.py deleted file mode 120000 index 5a015ee6c..000000000 --- a/egs/libritts/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained.py b/egs/libritts/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/libritts/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py b/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 120000 index 9a8da5844..000000000 --- a/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py b/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/joiner.py b/egs/libritts/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/libritts/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/label_smoothing.py b/egs/libritts/ASR/zipformer/label_smoothing.py deleted file mode 120000 index 175c633cc..000000000 --- a/egs/libritts/ASR/zipformer/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/model.py b/egs/libritts/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/libritts/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/my_profile.py b/egs/libritts/ASR/zipformer/my_profile.py deleted file mode 120000 index 3a90b2628..000000000 --- a/egs/libritts/ASR/zipformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_check.py b/egs/libritts/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/libritts/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py deleted file mode 100755 index 6f09cc8f7..000000000 --- a/egs/libritts/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -2. Run this file - -./zipformer/onnx_decode.py \ - --exp-dir $repo/exp \ - --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ -""" - - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import LibriTTSAsrDataModule -from k2 import SymbolTable -from onnx_pretrained import OnnxModel, greedy_search -from train import normalize_text - -from icefall.utils import setup_logger, store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - The token table. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. 
- """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - hyps = [token_ids_to_words(h).split() for h in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - The token table. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriTTSAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." 
- res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = SymbolTable.from_file(args.tokens) - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - libritts = LibriTTSAsrDataModule(args) - - test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) - test_other_cuts = libritts.test_other_cuts().map(normalize_text) - - test_clean_dl = libritts.test_dataloaders(test_clean_cuts) - test_other_dl = libritts.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py deleted file mode 120000 index d623a8462..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained.py b/egs/libritts/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py deleted file mode 120000 index a3183ebf6..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py deleted file mode 120000 index a4fd76ac2..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py deleted file mode 120000 index 
f805e3761..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py deleted file mode 120000 index 8343d5079..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py deleted file mode 120000 index 3568e7cab..000000000 --- a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/optim.py b/egs/libritts/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/libritts/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained.py b/egs/libritts/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/libritts/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained_ctc.py b/egs/libritts/ASR/zipformer/pretrained_ctc.py deleted file mode 120000 index c2f6f6fc3..000000000 --- a/egs/libritts/ASR/zipformer/pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling.py b/egs/libritts/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/libritts/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling_converter.py b/egs/libritts/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/libritts/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_beam_search.py b/egs/libritts/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/libritts/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py deleted file mode 100755 index b21018788..000000000 --- a/egs/libritts/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,901 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import LibriTTSAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet, set_caching_enabled -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params, normalize_text - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--label", - type=str, - default="", - help="""Extra label of the decoding run.""", - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - parser.add_argument( - "--skip-scoring", - type=str2bool, - default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. 
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - """ - Save text produced by ASR. - """ - for key, results in results_dict.items(): - recogs_filename = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recogs_filename, texts=results) - logging.info(f"The transcripts are stored in {recogs_filename}") - - -def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - """ - Save WER and per-utterance word alignments. - """ - test_set_wers = dict() - for key, results in results_dict.items(): - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w", encoding="utf8") as fd: - wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info(f"Wrote detailed error stats to {errs_filename}") - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - - wer_filename = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(wer_filename, "w", encoding="utf8") as fd: - print("settings\tWER", file=fd) - for key, val in test_set_wers: - print(f"{key}\t{val}", file=fd) - - s = f"\nFor {test_set_name}, WER of different settings are:\n" - note = f"\tbest for {test_set_name}" - for key, val in test_set_wers: - s += f"{key}\t{val}{note}\n" - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriTTSAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - # enable AudioCache - set_caching_enabled(True) # lhotse - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"_chunk-{params.chunk_size}" - params.suffix += f"_left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"_beam-{params.beam}" - params.suffix += f"_max-contexts-{params.max_contexts}" - params.suffix += f"_max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - if params.label: - params.suffix += f"-{params.label}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - libritts = LibriTTSAsrDataModule(args) - - test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) - test_other_cuts = libritts.test_other_cuts().map(normalize_text) - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_asr_output( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if not params.skip_scoring: - save_wer_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/ASR/zipformer/subsampling.py b/egs/libritts/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/libritts/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py deleted file mode 100755 index 5485eaf0a..000000000 --- a/egs/libritts/ASR/zipformer/train.py +++ /dev/null @@ -1,1527 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# Copyright 2024 The Chinese Univ. of HK (author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` - - ctc loss & attention decoder loss, no transducer loss, - with `--use-transducer False --use-ctc True --use-attention-decoder True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriTTSAsrDataModule -from attention_decoder import AttentionDecoderModel -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - parser.add_argument( - "--attention-decoder-dim", - type=int, - default=512, - help="""Dimension used in the attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-num-layers", - type=int, - default=6, - help="""Number of transformer layers used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-attention-dim", - type=int, - default=512, - help="""Attention dimension used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-num-heads", - type=int, - default=8, - help="""Number of attention heads used in attention decoder""", - ) - - parser.add_argument( - "--attention-decoder-feedforward-dim", - type=int, - default=2048, - help="""Feedforward dimension used in attention decoder""", - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - parser.add_argument( - "--use-attention-decoder", - type=str2bool, - default=False, - help="If True, use attention-decoder head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. 
We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--attention-decoder-loss-scale", - type=float, - default=0.8, - help="Scale for attention-decoder loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-bf16", - type=str2bool, - default=False, - help="Whether to use bf16 in AMP.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - # parameters for attention-decoder - "ignore_id": -1, - "label_smoothing": 0.1, - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def normalize_text(c: Cut): - def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s - - text = remove_punc_to_upper(c.supervisions[0].text) - c.supervisions[0].text = text - return c - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_attention_decoder_model(params: AttributeDict) -> nn.Module: - decoder = AttentionDecoderModel( - vocab_size=params.vocab_size, - decoder_dim=params.attention_decoder_dim, - num_decoder_layers=params.attention_decoder_num_layers, - attention_dim=params.attention_decoder_attention_dim, - num_heads=params.attention_decoder_num_heads, - feedforward_dim=params.attention_decoder_feedforward_dim, - memory_dim=max(_to_int_tuple(params.encoder_dim)), - sos_id=params.sos_id, - eos_id=params.eos_id, - ignore_id=params.ignore_id, - label_smoothing=params.label_smoothing, - ) - return decoder - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - if params.use_attention_decoder: - attention_decoder = get_attention_decoder_model(params) - else: - attention_decoder = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - attention_decoder=attention_decoder, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - use_attention_decoder=params.use_attention_decoder, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - if params.use_attention_decoder: - loss += params.attention_decoder_loss_scale * attention_decoder_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.use_attention_decoder: - info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. 
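For reference, the warm-up schedule that compute_loss above applies to the simple and pruned transducer losses can be written as a standalone function (the function and argument names here are illustrative only, not part of the deleted file):

def transducer_loss_scales(batch_idx_train: int, warm_step: int, simple_loss_scale: float):
    # Mirrors the scaling in compute_loss: the simple-loss weight decays
    # linearly from 1.0 down to simple_loss_scale over warm_step batches,
    # while the pruned-loss weight ramps up from 0.1 to 1.0.
    if batch_idx_train >= warm_step:
        return simple_loss_scale, 1.0
    frac = batch_idx_train / warm_step
    simple = 1.0 - frac * (1.0 - simple_loss_scale)
    pruned = 0.1 + 0.9 * frac
    return simple, pruned

# For example, with warm_step=2000 and simple_loss_scale=0.5:
#   batch 0    -> (1.00, 0.10)
#   batch 1000 -> (0.75, 0.55)
#   batch 2000 -> (0.50, 1.00)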
It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: - logging.info(f"Caught exception: {e}.") - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_autocast: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_autocast: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.sos_id = params.eos_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - if not params.use_attention_decoder: - params.ctc_loss_scale = 1.0 - else: - assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, - params.attention_decoder_loss_scale, - ) - - if params.use_bf16: # amp + bf16 - assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" 
- assert not params.use_fp16, "You can only use either fp16 or bf16" - params.dtype = torch.bfloat16 - params.use_autocast = True - elif params.use_fp16: # amp + fp16 - params.dtype = torch.float16 - params.use_autocast = True - else: # fp32 - params.dtype = torch.float32 - params.use_autocast = False - - logging.info(f"Using dtype={params.dtype}") - logging.info(f"Use AMP={params.use_autocast}") - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - libritts = LibriTTSAsrDataModule(args) - - if params.full_libri: - train_cuts = libritts.train_all_shuf_cuts() - - # previously we used the following code to load all training cuts, - # strictly speaking, shuffled training cuts should be used instead, - # but we leave the code here to demonstrate that there is an option - # like this to combine multiple cutsets - - # train_cuts = libritts.train_clean_100_cuts() - # train_cuts += libritts.train_clean_360_cuts() - # train_cuts += libritts.train_other_500_cuts() - else: - train_cuts = libritts.train_clean_100_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. 
" - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.map(normalize_text) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = libritts.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = libritts.dev_clean_cuts().map(normalize_text) - valid_cuts += libritts.dev_other_cuts().map(normalize_text) - valid_dl = libritts.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - LibriTTSAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libritts/ASR/zipformer/zipformer.py b/egs/libritts/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/libritts/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/libritts/CODEC/encodec/base_discriminators.py b/egs/libritts/CODEC/encodec/base_discriminators.py deleted file mode 100644 index 7bc035554..000000000 --- a/egs/libritts/CODEC/encodec/base_discriminators.py +++ /dev/null @@ -1,251 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
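The deleted base_discriminators.py that follows builds its period, scale and STFT sub-discriminators on top of two small "same"-padding helpers (get_padding and get_2d_padding). A quick self-contained check of that padding arithmetic with a plain Conv1d, independent of the deleted modules:

import torch
import torch.nn as nn

def get_padding(kernel_size: int, dilation: int = 1) -> int:
    # Same formula as the helper below: preserves the time length for
    # stride-1 convolutions with an odd kernel size.
    return (kernel_size * dilation - dilation) // 2

x = torch.randn(1, 1, 16000)  # (batch, channels, time)
for kernel_size, dilation in [(5, 1), (41, 1), (3, 4)]:
    conv = nn.Conv1d(
        1, 1, kernel_size, stride=1, dilation=dilation,
        padding=get_padding(kernel_size, dilation),
    )
    assert conv(x).shape[-1] == x.shape[-1]  # "same" length at stride 1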
- - -from typing import List, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio -from einops import rearrange -from modules.conv import NormConv1d, NormConv2d - - -def get_padding(kernel_size, dilation=1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): - return ( - ((kernel_size[0] - 1) * dilation[0]) // 2, - ((kernel_size[1] - 1) * dilation[1]) // 2, - ) - - -class DiscriminatorP(nn.Module): - def __init__( - self, - period, - kernel_size=5, - stride=3, - activation: str = "LeakyReLU", - activation_params: dict = {"negative_slope": 0.2}, - ): - super(DiscriminatorP, self).__init__() - - self.period = period - self.activation = getattr(torch.nn, activation)(**activation_params) - self.convs = nn.ModuleList( - [ - NormConv2d( - 1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0) - ), - NormConv2d( - 32, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ), - NormConv2d( - 32, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ), - NormConv2d( - 32, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ), - NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)), - ] - ) - self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0)) - - def forward(self, x): - fmap = [] - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = self.activation(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(nn.Module): - def __init__( - self, - activation: str = "LeakyReLU", - activation_params: dict = {"negative_slope": 0.2}, - ): - super(DiscriminatorS, self).__init__() - self.activation = getattr(torch.nn, activation)(**activation_params) - self.convs = nn.ModuleList( - [ - NormConv1d(1, 32, 15, 1, padding=7), - NormConv1d(32, 32, 41, 2, groups=4, padding=20), - NormConv1d(32, 32, 41, 2, groups=16, padding=20), - NormConv1d(32, 32, 41, 4, groups=16, padding=20), - NormConv1d(32, 32, 41, 4, groups=16, padding=20), - NormConv1d(32, 32, 41, 1, groups=16, padding=20), - NormConv1d(32, 32, 5, 1, padding=2), - ] - ) - self.conv_post = NormConv1d(32, 1, 3, 1, padding=1) - - def forward(self, x): - fmap = [] - for l in self.convs: - x = l(x) - x = self.activation(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - return x, fmap - - -class DiscriminatorSTFT(nn.Module): - """STFT sub-discriminator. - Args: - filters (int): Number of filters in convolutions - in_channels (int): Number of input channels. Default: 1 - out_channels (int): Number of output channels. Default: 1 - n_fft (int): Size of FFT for each scale. Default: 1024 - hop_length (int): Length of hop between STFT windows for each scale. Default: 256 - kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` - stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` - dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` - win_length (int): Window size for each scale. Default: 1024 - normalized (bool): Whether to normalize by magnitude after stft. Default: True - norm (str): Normalization method. 
Default: `'weight_norm'` - activation (str): Activation function. Default: `'LeakyReLU'` - activation_params (dict): Parameters to provide to the activation function. - growth (int): Growth factor for the filters. Default: 1 - """ - - def __init__( - self, - n_filters: int, - in_channels: int = 1, - out_channels: int = 1, - n_fft: int = 1024, - hop_length: int = 256, - win_length: int = 1024, - max_filters: int = 1024, - filters_scale: int = 1, - kernel_size: Tuple[int, int] = (3, 9), - dilations: List[int] = [1, 2, 4], - stride: Tuple[int, int] = (1, 2), - normalized: bool = True, - norm: str = "weight_norm", - activation: str = "LeakyReLU", - activation_params: dict = {"negative_slope": 0.2}, - ): - super().__init__() - assert len(kernel_size) == 2 - assert len(stride) == 2 - self.filters = n_filters - self.in_channels = in_channels - self.out_channels = out_channels - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.normalized = normalized - self.activation = getattr(torch.nn, activation)(**activation_params) - self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window_fn=torch.hann_window, - normalized=self.normalized, - center=False, - pad_mode=None, - power=None, - ) - spec_channels = 2 * self.in_channels - self.convs = nn.ModuleList() - self.convs.append( - NormConv2d( - spec_channels, - self.filters, - kernel_size=kernel_size, - padding=get_2d_padding(kernel_size), - ) - ) - in_chs = min(filters_scale * self.filters, max_filters) - for i, dilation in enumerate(dilations): - out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) - self.convs.append( - NormConv2d( - in_chs, - out_chs, - kernel_size=kernel_size, - stride=stride, - dilation=(dilation, 1), - padding=get_2d_padding(kernel_size, (dilation, 1)), - norm=norm, - ) - ) - in_chs = out_chs - out_chs = min( - (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters - ) - self.convs.append( - NormConv2d( - in_chs, - out_chs, - kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm, - ) - ) - self.conv_post = NormConv2d( - out_chs, - self.out_channels, - kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm, - ) - - def forward(self, x: torch.Tensor): - fmap = [] - z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] - z = torch.cat([z.real, z.imag], dim=1) - z = rearrange(z, "b c w t -> b c t w") - for i, layer in enumerate(self.convs): - z = layer(z) - z = self.activation(z) - fmap.append(z) - z = self.conv_post(z) - return z, fmap diff --git a/egs/libritts/CODEC/encodec/binary.py b/egs/libritts/CODEC/encodec/binary.py deleted file mode 100644 index 003bcfaf5..000000000 --- a/egs/libritts/CODEC/encodec/binary.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" - -import io -import json -import struct -from typing import IO, Any, List, Optional - -# format is `ECDC` magic code, followed by the header size as uint32. -# Then an uint8 indicates the protocol version (0.) 
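For reference, the ECDC header layout described in this comment (magic code, protocol version as uint8, JSON metadata size as uint32, packed by the "!4sBI" struct defined just below and used by write_ecdc_header / read_ecdc_header) round-trips like this; the metadata fields are made up for illustration:

import io
import json
import struct

header_struct = struct.Struct("!4sBI")  # magic, version (uint8), metadata size (uint32)
MAGIC = b"ECDC"

metadata = {"model": "demo", "sample_rate": 24000}  # hypothetical fields
meta_bytes = json.dumps(metadata).encode("utf-8")

buf = io.BytesIO()
buf.write(header_struct.pack(MAGIC, 0, len(meta_bytes)))
buf.write(meta_bytes)

buf.seek(0)
magic, version, meta_size = header_struct.unpack(buf.read(header_struct.size))
assert magic == MAGIC and version == 0
assert json.loads(buf.read(meta_size).decode("utf-8")) == metadata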
-# The header is then provided as json and should contain all required -# informations for decoding. A raw stream of bytes is then provided -# and should be interpretable using the json header. -_encodec_header_struct = struct.Struct("!4sBI") -_ENCODEC_MAGIC = b"ECDC" - - -def write_ecdc_header(fo: IO[bytes], metadata: Any): - meta_dumped = json.dumps(metadata).encode("utf-8") - version = 0 - header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped)) - fo.write(header) - fo.write(meta_dumped) - fo.flush() - - -def _read_exactly(fo: IO[bytes], size: int) -> bytes: - buf = b"" - while len(buf) < size: - new_buf = fo.read(size) - if not new_buf: - raise EOFError( - "Impossible to read enough data from the stream, " - f"{size} bytes remaining." - ) - buf += new_buf - size -= len(new_buf) - return buf - - -def read_ecdc_header(fo: IO[bytes]): - header_bytes = _read_exactly(fo, _encodec_header_struct.size) - magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) - if magic != _ENCODEC_MAGIC: - raise ValueError("File is not in ECDC format.") - if version != 0: - raise ValueError("Version not supported.") - meta_bytes = _read_exactly(fo, meta_size) - return json.loads(meta_bytes.decode("utf-8")) - - -class BitPacker: - """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. - Note that for some bandwidth (1.5, 3), the codebook representation - will not cover an integer number of bytes. - - Args: - bits (int): number of bits per value that will be pushed. - fo (IO[bytes]): file-object to push the bytes to. - """ - - def __init__(self, bits: int, fo: IO[bytes]): - self._current_value = 0 - self._current_bits = 0 - self.bits = bits - self.fo = fo - - def push(self, value: int): - """Push a new value to the stream. This will immediately - write as many uint8 as possible to the underlying file-object.""" - self._current_value += value << self._current_bits - self._current_bits += self.bits - while self._current_bits >= 8: - lower_8bits = self._current_value & 0xFF - self._current_bits -= 8 - self._current_value >>= 8 - self.fo.write(bytes([lower_8bits])) - - def flush(self): - """Flushes the remaining partial uint8, call this at the end - of the stream to encode.""" - if self._current_bits: - self.fo.write(bytes([self._current_value])) - self._current_value = 0 - self._current_bits = 0 - self.fo.flush() - - -class BitUnpacker: - """BitUnpacker does the opposite of `BitPacker`. - - Args: - bits (int): number of bits of the values to decode. - fo (IO[bytes]): file-object to push the bytes to. - """ - - def __init__(self, bits: int, fo: IO[bytes]): - self.bits = bits - self.fo = fo - self._mask = (1 << bits) - 1 - self._current_value = 0 - self._current_bits = 0 - - def pull(self) -> Optional[int]: - """ - Pull a single value from the stream, potentially reading some - extra bytes from the underlying file-object. - Returns `None` when reaching the end of the stream. 
- """ - while self._current_bits < self.bits: - buf = self.fo.read(1) - if not buf: - return None - character = buf[0] - self._current_value += character << self._current_bits - self._current_bits += 8 - - out = self._current_value & self._mask - self._current_value >>= self.bits - self._current_bits -= self.bits - return out - - -def test(): - import torch - - torch.manual_seed(1234) - for rep in range(4): - length: int = torch.randint(10, 2_000, (1,)).item() - bits: int = torch.randint(1, 16, (1,)).item() - tokens: List[int] = torch.randint(2**bits, (length,)).tolist() - rebuilt: List[int] = [] - buf = io.BytesIO() - packer = BitPacker(bits, buf) - for token in tokens: - packer.push(token) - packer.flush() - buf.seek(0) - unpacker = BitUnpacker(bits, buf) - while True: - value = unpacker.pull() - if value is None: - break - rebuilt.append(value) - assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) - # The flushing mechanism might lead to "ghost" values at the end of the stream. - assert len(rebuilt) <= len(tokens) + 8 // bits, ( - len(rebuilt), - len(tokens), - bits, - ) - for idx, (a, b) in enumerate(zip(tokens, rebuilt)): - assert a == b, (idx, a, b) - - -if __name__ == "__main__": - test() diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py deleted file mode 100644 index e77a255e5..000000000 --- a/egs/libritts/CODEC/encodec/codec_datamodule.py +++ /dev/null @@ -1,336 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibriTTSCodecDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="Codec data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="""When enabled, use the entire LibriTTS training set. - Otherwise, use the clean-100 subset.""", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - world_size: Optional[int] = None, - rank: Optional[int] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=False, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - world_size=world_size, - rank=rank, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - world_size=world_size, - rank=rank, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders( - self, - cuts_valid: CutSet, - world_size: Optional[int] = None, - rank: Optional[int] = None, - ) -> DataLoader: - logging.info("About to create dev dataset") - - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=False, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - world_size=world_size, - rank=rank, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=1, - drop_last=False, - persistent_workers=True, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=False, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - 
logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" - ) diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py deleted file mode 100644 index e6b7f0929..000000000 --- a/egs/libritts/CODEC/encodec/discriminators.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List - -import torch -import torch.nn as nn -from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT -from torch.nn import AvgPool1d - - -class MultiPeriodDiscriminator(nn.Module): - def __init__(self): - super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList( - [ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), - ] - ) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class MultiScaleDiscriminator(nn.Module): - def __init__(self): - super(MultiScaleDiscriminator, self).__init__() - self.discriminators = nn.ModuleList( - [ - DiscriminatorS(), - DiscriminatorS(), - DiscriminatorS(), - ] - ) - self.meanpools = nn.ModuleList( - [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] - ) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - if i != 0: - y = self.meanpools[i - 1](y) - y_hat = self.meanpools[i - 1](y_hat) - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class MultiScaleSTFTDiscriminator(nn.Module): - """Multi-Scale STFT (MS-STFT) discriminator. - Args: - filters (int): Number of filters in convolutions - in_channels (int): Number of input channels. Default: 1 - out_channels (int): Number of output channels. 
Default: 1 - n_ffts (Sequence[int]): Size of FFT for each scale - hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale - win_lengths (Sequence[int]): Window size for each scale - **kwargs: additional args for STFTDiscriminator - """ - - def __init__( - self, - n_filters: int, - in_channels: int = 1, - out_channels: int = 1, - n_ffts: List[int] = [1024, 2048, 512, 256, 128], - hop_lengths: List[int] = [256, 512, 128, 64, 32], - win_lengths: List[int] = [1024, 2048, 512, 256, 128], - **kwargs - ): - super().__init__() - assert len(n_ffts) == len(hop_lengths) == len(win_lengths) - self.discriminators = nn.ModuleList( - [ - DiscriminatorSTFT( - n_filters, - in_channels=in_channels, - out_channels=out_channels, - n_fft=n_ffts[i], - win_length=win_lengths[i], - hop_length=hop_lengths[i], - **kwargs - ) - for i in range(len(n_ffts)) - ] - ) - self.num_discriminators = len(self.discriminators) - - def forward(self, x: torch.Tensor): - logits = [] - fmaps = [] - for disc in self.discriminators: - logit, fmap = disc(x) - logits.append(logit) - fmaps.append(fmap) - return logits, fmaps diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py deleted file mode 100644 index f21d494b6..000000000 --- a/egs/libritts/CODEC/encodec/encodec.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
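MultiScaleSTFTDiscriminator above analyses audio at five STFT resolutions, and the deleted encodec.py below feeds both real and reconstructed waveforms through it. The default (n_fft, hop_length, win_length) triples trade frequency resolution against time resolution, which a plain torch.stft call makes visible (illustrative only, no dependency on the deleted code):

import torch

signal = torch.randn(24000)  # 1 s of audio at an assumed 24 kHz sampling rate

for n_fft, hop, win in zip(
    [1024, 2048, 512, 256, 128],  # n_ffts
    [256, 512, 128, 64, 32],      # hop_lengths
    [1024, 2048, 512, 256, 128],  # win_lengths
):
    spec = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop,
        win_length=win,
        window=torch.hann_window(win),
        center=False,
        return_complex=True,
    )
    freq_bins, frames = spec.shape
    # finer frequency resolution <-> coarser time resolution, and vice versa
    print(f"n_fft={n_fft:4d}  freq_bins={freq_bins:4d}  frames={frames:4d}")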
- -import math -import random -from typing import List, Optional - -import numpy as np -import torch -from loss import ( - DiscriminatorAdversarialLoss, - FeatureLoss, - GeneratorAdversarialLoss, - MelSpectrogramReconstructionLoss, - WavReconstructionLoss, -) -from torch import nn -from torch.cuda.amp import autocast - - -class Encodec(nn.Module): - def __init__( - self, - sampling_rate: int, - target_bandwidths: List[float], - params: dict, - encoder: nn.Module, - quantizer: nn.Module, - decoder: nn.Module, - multi_scale_discriminator: nn.Module, - multi_period_discriminator: Optional[nn.Module] = None, - multi_scale_stft_discriminator: Optional[nn.Module] = None, - cache_generator_outputs: bool = False, - ): - super(Encodec, self).__init__() - - self.params = params - - # setup the generator - self.sampling_rate = sampling_rate - self.encoder = encoder - self.quantizer = quantizer - self.decoder = decoder - - self.ratios = encoder.ratios - self.hop_length = np.prod(self.ratios) - self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios)) - self.target_bandwidths = target_bandwidths - - # discriminators - self.multi_scale_discriminator = multi_scale_discriminator - self.multi_period_discriminator = multi_period_discriminator - self.multi_scale_stft_discriminator = multi_scale_stft_discriminator - - # cache - self.cache_generator_outputs = cache_generator_outputs - self._cache = None - - # construct loss functions - self.generator_adversarial_loss = GeneratorAdversarialLoss( - average_by_discriminators=True, loss_type="hinge" - ) - self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( - average_by_discriminators=True, loss_type="hinge" - ) - self.feature_match_loss = FeatureLoss() - self.wav_reconstruction_loss = WavReconstructionLoss() - self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( - sampling_rate=self.sampling_rate - ) - - def _forward_generator( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool = False, - ): - """Perform generator forward. - - Args: - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - return_sample (bool): Return the generator output. - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. 
- """ - # setup - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - e = self.encoder(speech) - index = torch.tensor( - random.randint(0, len(self.target_bandwidths) - 1), - device=speech.device, - ) - if torch.distributed.is_initialized(): - torch.distributed.broadcast(index, src=0) - bw = self.target_bandwidths[index.item()] - quantized, codes, bandwidth, commit_loss = self.quantizer( - e, self.frame_rate, bw - ) - speech_hat = self.decoder(quantized) - else: - speech_hat = self._cache - # store cache - if self.training and self.cache_generator_outputs and not reuse_cache: - self._cache = speech_hat - - # calculate discriminator outputs - y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous()) - with torch.no_grad(): - # do not store discriminator gradient in generator turn - y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) - - gen_period_adv_loss = torch.tensor(0.0) - feature_period_loss = torch.tensor(0.0) - if self.multi_period_discriminator is not None: - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) - - gen_scale_adv_loss = torch.tensor(0.0) - feature_scale_loss = torch.tensor(0.0) - if self.multi_scale_discriminator is not None: - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) - - # calculate losses - with autocast(enabled=False): - gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) - - if self.multi_period_discriminator is not None: - gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) - if self.multi_scale_discriminator is not None: - gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) - - feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) - - if self.multi_period_discriminator is not None: - feature_period_loss = self.feature_match_loss( - feats=fmap_p, feats_hat=fmap_p_hat - ) - if self.multi_scale_discriminator is not None: - feature_scale_loss = self.feature_match_loss( - feats=fmap_s, feats_hat=fmap_s_hat - ) - - wav_reconstruction_loss = self.wav_reconstruction_loss( - x=speech, x_hat=speech_hat - ) - mel_reconstruction_loss = self.mel_reconstruction_loss( - x=speech, x_hat=speech_hat - ) - - stats = dict( - generator_wav_reconstruction_loss=wav_reconstruction_loss.item(), - generator_mel_reconstruction_loss=mel_reconstruction_loss.item(), - generator_feature_stft_loss=feature_stft_loss.item(), - generator_feature_period_loss=feature_period_loss.item(), - generator_feature_scale_loss=feature_scale_loss.item(), - generator_stft_adv_loss=gen_stft_adv_loss.item(), - generator_period_adv_loss=gen_period_adv_loss.item(), - generator_scale_adv_loss=gen_scale_adv_loss.item(), - generator_commit_loss=commit_loss.item(), - ) - - if return_sample: - stats["returned_sample"] = ( - speech_hat.cpu(), - speech.cpu(), - fmap_hat[0][0].data.cpu(), - fmap[0][0].data.cpu(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - return ( - commit_loss, - gen_stft_adv_loss, - gen_period_adv_loss, - gen_scale_adv_loss, - feature_stft_loss, - feature_period_loss, - feature_scale_loss, - wav_reconstruction_loss, - mel_reconstruction_loss, - stats, - ) - - def _forward_discriminator( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ): - """ - Args: - speech (Tensor): 
Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - """ - # setup - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - e = self.encoder(speech) - index = torch.tensor( - random.randint(0, len(self.target_bandwidths) - 1), - device=speech.device, - ) - if torch.distributed.is_initialized(): - torch.distributed.broadcast(index, src=0) - bw = self.target_bandwidths[index.item()] - quantized, codes, bandwidth, commit_loss = self.quantizer( - e, self.frame_rate, bw - ) - speech_hat = self.decoder(quantized) - else: - speech_hat = self._cache - - # store cache - if self.training and self.cache_generator_outputs and not reuse_cache: - self._cache = speech_hat - - # calculate discriminator outputs - y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) - y_hat, fmap_hat = self.multi_scale_stft_discriminator( - speech_hat.contiguous().detach() - ) - - disc_period_real_adv_loss = torch.tensor(0.0) - disc_period_fake_adv_loss = torch.tensor(0.0) - if self.multi_period_discriminator is not None: - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) - - disc_scale_real_adv_loss = torch.tensor(0.0) - disc_scale_fake_adv_loss = torch.tensor(0.0) - if self.multi_scale_discriminator is not None: - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) - # calculate losses - with autocast(enabled=False): - ( - disc_stft_real_adv_loss, - disc_stft_fake_adv_loss, - ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) - if self.multi_period_discriminator is not None: - ( - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - ) = self.discriminator_adversarial_loss( - outputs=y_p, outputs_hat=y_p_hat - ) - if self.multi_scale_discriminator is not None: - ( - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - ) = self.discriminator_adversarial_loss( - outputs=y_s, outputs_hat=y_s_hat - ) - - stats = dict( - discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), - discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(), - discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(), - discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(), - discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(), - discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return ( - disc_stft_real_adv_loss, - disc_stft_fake_adv_loss, - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - stats, - ) - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool, - forward_generator: bool, - ): - if forward_generator: - return self._forward_generator( - speech=speech, - speech_lengths=speech_lengths, - return_sample=return_sample, - ) - else: - return self._forward_discriminator( - speech=speech, - speech_lengths=speech_lengths, - ) - - def encode(self, x, target_bw=None, st=None): - e = self.encoder(x) - if target_bw is None: - bw = self.target_bandwidths[-1] - else: - bw = target_bw - if st is None: - st = 0 - codes = 
self.quantizer.encode(e, self.frame_rate, bw, st) - return codes - - def decode(self, codes): - quantized = self.quantizer.decode(codes) - x_hat = self.decoder(quantized) - return x_hat - - def inference(self, x, target_bw=None, st=None): - # setup - x = x.unsqueeze(1) - - codes = self.encode(x, target_bw, st) - x_hat = self.decode(codes) - return codes, x_hat diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py deleted file mode 100755 index 3c6ea15f9..000000000 --- a/egs/libritts/CODEC/encodec/infer.py +++ /dev/null @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script performs model inference on test set. - -Usage: -./codec/infer.py \ - --epoch 300 \ - --exp-dir ./codec/exp \ - --max-duration 500 -""" - - -import argparse -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path -from statistics import mean -from typing import List, Tuple - -import numpy as np -import torch -import torchaudio -from codec_datamodule import LibriTTSCodecDataModule -from pesq import pesq -from pystoi import stoi -from scipy import signal -from torch import nn -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="encodec/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--target-bw", - type=float, - default=24, - help="The target bandwidth for the generator", - ) - - return parser - - -# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py -def remove_encodec_weight_norm(model) -> None: - from modules import SConv1d - from modules.seanet import SConvTranspose1d, SEANetResnetBlock - from torch.nn.utils import remove_weight_norm - - encoder = model.encoder.model - for key in encoder._modules: - if isinstance(encoder._modules[key], SEANetResnetBlock): - remove_weight_norm(encoder._modules[key].shortcut.conv.conv) - block_modules = encoder._modules[key].block._modules - for skey in block_modules: - if isinstance(block_modules[skey], SConv1d): - remove_weight_norm(block_modules[skey].conv.conv) - elif isinstance(encoder._modules[key], SConv1d): - remove_weight_norm(encoder._modules[key].conv.conv) - - decoder = model.decoder.model - for key in decoder._modules: - if isinstance(decoder._modules[key], SEANetResnetBlock): - remove_weight_norm(decoder._modules[key].shortcut.conv.conv) - block_modules = decoder._modules[key].block._modules - for skey in block_modules: - if isinstance(block_modules[skey], SConv1d): - remove_weight_norm(block_modules[skey].conv.conv) - elif isinstance(decoder._modules[key], SConvTranspose1d): - remove_weight_norm(decoder._modules[key].convtr.convtr) - elif isinstance(decoder._modules[key], SConv1d): - remove_weight_norm(decoder._modules[key].conv.conv) - - -def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float: - """Compute PESQ score between reference and generated audio.""" - DEFAULT_SAMPLING_RATE = 16000 - ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE) - deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE) - return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb") - - -def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float: - """Compute STOI score between reference and generated audio.""" - return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False) - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - subset: str, - params: AttributeDict, - model: nn.Module, -) -> Tuple[float, float]: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - subset: - The name of the subset. - params: - It is returned by :func:`get_params`. - model: - The neural model. - - Returns: - The average PESQ and STOI scores. - """ - - # Background worker save audios to disk. - def _save_worker( - subset: str, - batch_size: int, - cut_ids: List[str], - audio: torch.Tensor, - audio_pred: torch.Tensor, - audio_lens: List[int], - ): - for i in range(batch_size): - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), - audio[i : i + 1, : audio_lens[i]], - sample_rate=params.sampling_rate, - ) - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"), - audio_pred[i : i + 1, : audio_lens[i]], - sample_rate=params.sampling_rate, - ) - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - pesq_wb_scores = [] - stoi_scores = [] - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
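compute_pesq and compute_stoi above are thin wrappers around the pesq and pystoi packages. Called directly on a pair of 16 kHz waveforms, the underlying metrics look roughly like this (synthetic signals, purely illustrative; PESQ is defined for speech and may reject non-speech input, which is why the deleted code also guards it with try/except):

import numpy as np
from pesq import pesq
from pystoi import stoi

sampling_rate = 16000  # wide-band PESQ expects 16 kHz input
ref = np.random.randn(sampling_rate).astype(np.float32)                 # stand-in reference
deg = (0.9 * ref + 0.1 * np.random.randn(len(ref))).astype(np.float32)  # stand-in reconstruction

try:
    pesq_wb = pesq(fs=sampling_rate, ref=ref, deg=deg, mode="wb")
    print(f"PESQ-WB: {pesq_wb:.3f}")
except Exception as e:  # e.g. no utterance detected in synthetic audio
    print(f"PESQ failed: {e}")

stoi_score = stoi(x=ref, y=deg, fs_sig=sampling_rate, extended=False)
print(f"STOI: {stoi_score:.3f}")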
- - futures = [] - with ThreadPoolExecutor(max_workers=1) as executor: - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["audio"]) - - audios = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - - codes, audio_hats = model.inference( - audios.to(device), target_bw=params.target_bw - ) - audio_hats = audio_hats.squeeze(1).cpu() - - for cut_id, audio, audio_hat, audio_len in zip( - cut_ids, audios, audio_hats, audio_lens - ): - try: - pesq_wb = compute_pesq( - ref_wav=audio[:audio_len].numpy(), - gen_wav=audio_hat[:audio_len].numpy(), - ) - pesq_wb_scores.append(pesq_wb) - except Exception as e: - logging.error(f"Error while computing PESQ for cut {cut_id}: {e}") - - stoi_score = compute_stoi( - ref_wav=audio[:audio_len].numpy(), - gen_wav=audio_hat[:audio_len].numpy(), - sampling_rate=params.sampling_rate, - ) - stoi_scores.append(stoi_score) - - futures.append( - executor.submit( - _save_worker, - subset, - batch_size, - cut_ids, - audios, - audio_hats, - audio_lens, - ) - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - # return results - for f in futures: - f.result() - return mean(pesq_wb_scores), mean(stoi_scores) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriTTSCodecDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - # we need cut ids to display results of both constructed and ground-truth audio - args.return_cuts = True - libritts = LibriTTSCodecDataModule(args) - - logging.info(f"Device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - remove_encodec_weight_norm(model) - - model.to(device) - model.eval() - - encoder = model.encoder - decoder = model.decoder - quantizer = model.quantizer - multi_scale_discriminator = model.multi_scale_discriminator - multi_period_discriminator = model.multi_period_discriminator - multi_scale_stft_discriminator = model.multi_scale_stft_discriminator - - num_param_e = sum([p.numel() for p in encoder.parameters()]) - logging.info(f"Number of parameters in encoder: {num_param_e}") - num_param_d = sum([p.numel() for p in decoder.parameters()]) - logging.info(f"Number of parameters in decoder: {num_param_d}") - num_param_q = sum([p.numel() for p in quantizer.parameters()]) - logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = ( - sum([p.numel() for p in multi_scale_discriminator.parameters()]) - if multi_scale_discriminator is not None - else 0 - ) - logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = ( - sum([p.numel() for p in multi_period_discriminator.parameters()]) - if multi_period_discriminator is not None - else 0 - ) - logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") - 
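The saving logic in `infer_dataset` above offloads disk writes to a single background thread so decoding is not blocked on I/O. Reduced to a sketch (the paths, tensors, and 24 kHz rate are placeholders):

    from concurrent.futures import ThreadPoolExecutor

    import torch
    import torchaudio

    def save_wav(path: str, wav: torch.Tensor, sr: int = 24000) -> None:
        # wav is (channels, num_samples), the layout torchaudio.save expects
        torchaudio.save(path, wav, sample_rate=sr)

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        for i in range(3):
            wav = torch.zeros(1, 24000)  # placeholder: one second of silence
            futures.append(executor.submit(save_wav, f"/tmp/cut_{i}_recon.wav", wav))
        for f in futures:
            f.result()  # surface any I/O error before leaving the pool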
num_param_dstft = sum( - [p.numel() for p in multi_scale_stft_discriminator.parameters()] - ) - logging.info( - f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" - ) - logging.info( - f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" - ) - - test_clean_cuts = libritts.test_clean_cuts() - test_clean = libritts.test_dataloaders(test_clean_cuts) - - test_other_cuts = libritts.test_other_cuts() - test_other = libritts.test_dataloaders(test_other_cuts) - - dev_clean_cuts = libritts.dev_clean_cuts() - dev_clean = libritts.valid_dataloaders(dev_clean_cuts) - - dev_other_cuts = libritts.dev_other_cuts() - dev_other = libritts.valid_dataloaders(dev_other_cuts) - - infer_sets = { - "test-clean": test_clean, - "test-other": test_other, - "dev-clean": dev_clean, - "dev-other": dev_other, - } - - for subset, dl in infer_sets.items(): - save_wav_dir = params.res_dir / "wav" / subset - save_wav_dir.mkdir(parents=True, exist_ok=True) - - logging.info(f"Processing {subset} set, saving to {save_wav_dir}") - - pesq_wb, stoi = infer_dataset( - dl=dl, - subset=subset, - params=params, - model=model, - ) - logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}") - - logging.info(f"Wav files are saved to {params.save_wav_dir}") - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py deleted file mode 100644 index 9cf1d42d2..000000000 --- a/egs/libritts/CODEC/encodec/loss.py +++ /dev/null @@ -1,321 +0,0 @@ -# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University) -# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Encodec-related loss modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. - -""" - -from typing import List, Tuple, Union - -import torch -import torch.nn.functional as F -from torchaudio.transforms import MelSpectrogram - - -class GeneratorAdversarialLoss(torch.nn.Module): - """Generator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "hinge", - ): - """Initialize GeneratorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.criterion = self._mse_loss - else: - self.criterion = self._hinge_loss - - def forward( - self, - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> torch.Tensor: - """Calcualate generator adversarial loss. - - Args: - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs.. - - Returns: - Tensor: Generator adversarial loss value. 
- - """ - adv_loss = 0.0 - if isinstance(outputs, (tuple, list)): - for i, outputs_ in enumerate(outputs): - if isinstance(outputs_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_ = outputs_[-1] - adv_loss += self.criterion(outputs_) - if self.average_by_discriminators: - adv_loss /= i + 1 - else: - for i, outputs_ in enumerate(outputs): - adv_loss += self.criterion(outputs_) - adv_loss /= i + 1 - return adv_loss - - def _mse_loss(self, x): - return F.mse_loss(x, x.new_ones(x.size())) - - def _hinge_loss(self, x): - return F.relu(1 - x).mean() - - -class DiscriminatorAdversarialLoss(torch.nn.Module): - """Discriminator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "hinge", - ): - """Initialize DiscriminatorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.fake_criterion = self._mse_fake_loss - self.real_criterion = self._mse_real_loss - else: - self.fake_criterion = self._hinge_fake_loss - self.real_criterion = self._hinge_real_loss - - def forward( - self, - outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calcualate discriminator adversarial loss. - - Args: - outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from generator. - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from groundtruth. - - Returns: - Tensor: Discriminator real loss value. - Tensor: Discriminator fake loss value. 
- - """ - real_loss = 0.0 - fake_loss = 0.0 - if isinstance(outputs, (tuple, list)): - for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): - if isinstance(outputs_hat_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_hat_ = outputs_hat_[-1] - outputs_ = outputs_[-1] - real_loss += self.real_criterion(outputs_) - fake_loss += self.fake_criterion(outputs_hat_) - if self.average_by_discriminators: - fake_loss /= i + 1 - real_loss /= i + 1 - else: - for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): - real_loss += self.real_criterion(outputs_) - fake_loss += self.fake_criterion(outputs_hat_) - fake_loss /= i + 1 - real_loss /= i + 1 - - return real_loss, fake_loss - - def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_ones(x.size())) - - def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_zeros(x.size())) - - def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.relu(torch.ones_like(x) - x).mean() - - def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.relu(torch.ones_like(x) + x).mean() - - -class FeatureLoss(torch.nn.Module): - """Feature loss module.""" - - def __init__( - self, - average_by_layers: bool = True, - average_by_discriminators: bool = True, - include_final_outputs: bool = True, - ): - """Initialize FeatureMatchLoss module. - - Args: - average_by_layers (bool): Whether to average the loss by the number - of layers. - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - include_final_outputs (bool): Whether to include the final output of - each discriminator for loss calculation. - - """ - super().__init__() - self.average_by_layers = average_by_layers - self.average_by_discriminators = average_by_discriminators - self.include_final_outputs = include_final_outputs - - def forward( - self, - feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], - feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], - ) -> torch.Tensor: - """Calculate feature matching loss. - - Args: - feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from generator's outputs. - feats (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from groundtruth.. - - Returns: - Tensor: Feature matching loss value. 
- - """ - feat_match_loss = 0.0 - for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): - feat_match_loss_ = 0.0 - if not self.include_final_outputs: - feats_hat_ = feats_hat_[:-1] - feats_ = feats_[:-1] - for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += ( - F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean()) - ).mean() - if self.average_by_layers: - feat_match_loss_ /= j + 1 - feat_match_loss += feat_match_loss_ - if self.average_by_discriminators: - feat_match_loss /= i + 1 - - return feat_match_loss - - -class MelSpectrogramReconstructionLoss(torch.nn.Module): - """Mel Spec Reconstruction loss.""" - - def __init__( - self, - sampling_rate: int = 22050, - n_mels: int = 64, - use_fft_mag: bool = True, - return_mel: bool = False, - ): - super().__init__() - self.wav_to_specs = [] - for i in range(5, 12): - s = 2**i - self.wav_to_specs.append( - MelSpectrogram( - sample_rate=sampling_rate, - n_fft=max(s, 512), - win_length=s, - hop_length=s // 4, - n_mels=n_mels, - ) - ) - self.return_mel = return_mel - - def forward( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: - """Calculate Mel-spectrogram loss. - - Args: - x_hat (Tensor): Generated waveform tensor (B, 1, T). - x (Tensor): Groundtruth waveform tensor (B, 1, T). - spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor - (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth - waveform. - - Returns: - Tensor: Mel-spectrogram loss value. - - """ - mel_loss = 0.0 - - for i, wav_to_spec in enumerate(self.wav_to_specs): - s = 2 ** (i + 5) - wav_to_spec.to(x.device) - - mel_hat = wav_to_spec(x_hat.squeeze(1)) - mel = wav_to_spec(x.squeeze(1)) - - mel_loss += ( - F.l1_loss(mel_hat, mel, reduce=True, reduction="mean") - + ( - ( - (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) - ** 2 - ).mean(dim=-2) - ** 0.5 - ).mean() - ) - - # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) - # mel = self.wav_to_spec(x.squeeze(1)) - # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel) - - if self.return_mel: - return mel_loss, (mel_hat, mel) - - return mel_loss - - -class WavReconstructionLoss(torch.nn.Module): - """Wav Reconstruction loss.""" - - def __init__(self): - super().__init__() - - def forward( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - ) -> torch.Tensor: - """Calculate wav loss. - - Args: - x_hat (Tensor): Generated waveform tensor (B, 1, T). - x (Tensor): Groundtruth waveform tensor (B, 1, T). - - Returns: - Tensor: Wav loss value. - - """ - wav_loss = F.l1_loss(x, x_hat) - - return wav_loss diff --git a/egs/libritts/CODEC/encodec/modules/__init__.py b/egs/libritts/CODEC/encodec/modules/__init__.py deleted file mode 100644 index b903a28b0..000000000 --- a/egs/libritts/CODEC/encodec/modules/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Torch modules.""" -# flake8: noqa -from .conv import ( - NormConv1d, - NormConv2d, - NormConvTranspose1d, - NormConvTranspose2d, - SConv1d, - SConvTranspose1d, - pad1d, - unpad1d, -) -from .lstm import SLSTM -from .seanet import SEANetDecoder, SEANetEncoder -from .transformer import StreamingTransformerEncoder diff --git a/egs/libritts/CODEC/encodec/modules/conv.py b/egs/libritts/CODEC/encodec/modules/conv.py deleted file mode 100644 index a70a5c67f..000000000 --- a/egs/libritts/CODEC/encodec/modules/conv.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Convolutional layers wrappers and utilities.""" -import logging -import math -from typing import Any, Dict, Tuple - -from torch import Tensor, nn -from torch.nn import functional as F -from torch.nn.utils import spectral_norm, weight_norm - -from .norm import ConvLayerNorm - -CONV_NORMALIZATIONS = frozenset( - [ - "none", - "weight_norm", - "spectral_norm", - "time_layer_norm", - "layer_norm", - "time_group_norm", - ] -) - - -def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: - assert norm in CONV_NORMALIZATIONS - if norm == "weight_norm": - return weight_norm(module) - elif norm == "spectral_norm": - return spectral_norm(module) - else: - # We already check was in CONV_NORMALIZATION, so any other choice - # doesn't need reparametrization. - return module - - -def get_norm_module( - module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs -) -> nn.Module: - """Return the proper normalization module. If causal is True, this will ensure the returned - module is causal, or return an error if the normalization doesn't support causal evaluation. - """ - assert norm in CONV_NORMALIZATIONS - if norm == "layer_norm": - assert isinstance(module, nn.modules.conv._ConvNd) - return ConvLayerNorm(module.out_channels, **norm_kwargs) - elif norm == "time_group_norm": - if causal: - raise ValueError("GroupNorm doesn't support causal evaluation.") - assert isinstance(module, nn.modules.conv._ConvNd) - return nn.GroupNorm(1, module.out_channels, **norm_kwargs) - else: - return nn.Identity() - - -def get_extra_padding_for_conv1d( - x: Tensor, kernel_size: int, stride: int, padding_total: int = 0 -) -> int: - """See `pad_for_conv1d`.""" - length = x.shape[-1] - n_frames = (length - kernel_size + padding_total) / stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) - return ideal_length - length - - -def pad_for_conv1d(x: Tensor, kernel_size: int, stride: int, padding_total: int = 0): - """Pad for a convolution to make sure that the last window is full. - Extra padding is added at the end. This is required to ensure that we can rebuild - an output of the same length, as otherwise, even with padding, some time steps - might get removed. - For instance, with total padding = 4, kernel size = 4, stride = 2: - 0 0 1 2 3 4 5 0 0 # (0s are padding) - 1 2 3 # (output frames of a convolution, last 0 is never used) - 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) - 1 2 3 4 # once you removed padding, we are missing one time step ! 
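Plugging the docstring's own numbers (length 5, kernel 4, stride 2, total padding 4) into the formulas above makes the extra-padding step concrete:

    import math

    import torch

    length, kernel_size, stride, padding_total = 5, 4, 2, 4
    n_frames = (length - kernel_size + padding_total) / stride + 1                     # 3.5
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)  # 6
    extra_padding = ideal_length - length                                              # 1

    x = torch.zeros(1, 1, length)
    assert pad_for_conv1d(x, kernel_size, stride, padding_total).shape[-1] == 6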
- """ - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - return F.pad(x, (0, extra_padding)) - - -def pad1d( - x: Tensor, - paddings: Tuple[int, int], - mode: str = "zero", - value: float = 0.0, -): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == "reflect": - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -def unpad1d(x: Tensor, paddings: Tuple[int, int]): - """Remove padding from x, handling properly zero padding. Only for 1d!""" - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - assert (padding_left + padding_right) <= x.shape[-1] - end = x.shape[-1] - padding_right - return x[..., padding_left:end] - - -class NormConv1d(nn.Module): - """Wrapper around Conv1d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - - def __init__( - self, - *args, - causal: bool = False, - norm: str = "none", - norm_kwargs: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) - self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.conv(x) - x = self.norm(x) - return x - - -class NormConv2d(nn.Module): - """Wrapper around Conv2d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - - def __init__( - self, - *args, - norm: str = "none", - norm_kwargs: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) - self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.conv(x) - x = self.norm(x) - return x - - -class NormConvTranspose1d(nn.Module): - """Wrapper around ConvTranspose1d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - - def __init__( - self, - *args, - causal: bool = False, - norm: str = "none", - norm_kwargs: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - self.convtr = apply_parametrization_norm( - nn.ConvTranspose1d(*args, **kwargs), norm - ) - self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.convtr(x) - x = self.norm(x) - return x - - -class NormConvTranspose2d(nn.Module): - """Wrapper around ConvTranspose2d and normalization applied to this conv - to provide a uniform interface across normalization approaches. 
- """ - - def __init__( - self, - *args, - norm: str = "none", - norm_kwargs: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - self.convtr = apply_parametrization_norm( - nn.ConvTranspose2d(*args, **kwargs), norm - ) - self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) - - def forward(self, x): - x = self.convtr(x) - x = self.norm(x) - return x - - -class SConv1d(nn.Module): - """Conv1d with some builtin handling of asymmetric or causal padding - and normalization. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - causal: bool = False, - norm: str = "none", - norm_kwargs: Dict[str, Any] = {}, - pad_mode: str = "reflect", - ): - super().__init__() - # warn user on unusual setup between dilation and stride - if stride > 1 and dilation > 1: - logging.warning( - "SConv1d has been initialized with stride > 1 and dilation > 1" - f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." - ) - self.conv = NormConv1d( - in_channels, - out_channels, - kernel_size, - stride, - dilation=dilation, - groups=groups, - bias=bias, - causal=causal, - norm=norm, - norm_kwargs=norm_kwargs, - ) - self.causal = causal - self.pad_mode = pad_mode - - def forward(self, x): - B, C, T = x.shape - kernel_size = self.conv.conv.kernel_size[0] - stride = self.conv.conv.stride[0] - dilation = self.conv.conv.dilation[0] - padding_total = (kernel_size - 1) * dilation - (stride - 1) - extra_padding = get_extra_padding_for_conv1d( - x, kernel_size, stride, padding_total - ) - if self.causal: - # Left padding for causal - x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - x = pad1d( - x, (padding_left, padding_right + extra_padding), mode=self.pad_mode - ) - return self.conv(x) - - -class SConvTranspose1d(nn.Module): - """ConvTranspose1d with some builtin handling of asymmetric or causal padding - and normalization. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - causal: bool = False, - norm: str = "none", - trim_right_ratio: float = 1.0, - norm_kwargs: Dict[str, Any] = {}, - ): - super().__init__() - self.convtr = NormConvTranspose1d( - in_channels, - out_channels, - kernel_size, - stride, - causal=causal, - norm=norm, - norm_kwargs=norm_kwargs, - ) - self.causal = causal - self.trim_right_ratio = trim_right_ratio - assert ( - self.causal or self.trim_right_ratio == 1.0 - ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" - assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 - - def forward(self, x): - kernel_size = self.convtr.convtr.kernel_size[0] - stride = self.convtr.convtr.stride[0] - padding_total = kernel_size - stride - - y = self.convtr(x) - - # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be - # removed at the very end, when keeping only the right length for the output, - # as removing it here would require also passing the length at the matching layer - # in the encoder. 
- if self.causal: - # Trim the padding on the right according to the specified ratio - # if trim_right_ratio = 1.0, trim everything from right - padding_right = math.ceil(padding_total * self.trim_right_ratio) - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - return y diff --git a/egs/libritts/CODEC/encodec/modules/lstm.py b/egs/libritts/CODEC/encodec/modules/lstm.py deleted file mode 100644 index 5307552c0..000000000 --- a/egs/libritts/CODEC/encodec/modules/lstm.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""LSTM layers module.""" -from torch import nn - - -class SLSTM(nn.Module): - """ - LSTM without worrying about the hidden state, nor the layout of the data. - Expects input as convolutional layout. - """ - - def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): - super().__init__() - self.skip = skip - self.lstm = nn.LSTM(dimension, dimension, num_layers) - - def forward(self, x): - x = x.permute(2, 0, 1) - y, _ = self.lstm(x) - if self.skip: - y = y + x - y = y.permute(1, 2, 0) - return y diff --git a/egs/libritts/CODEC/encodec/modules/norm.py b/egs/libritts/CODEC/encodec/modules/norm.py deleted file mode 100644 index 3002b3a26..000000000 --- a/egs/libritts/CODEC/encodec/modules/norm.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Normalization modules.""" - -from typing import List, Union - -import einops -import torch -from torch import nn - - -class ConvLayerNorm(nn.LayerNorm): - """ - Convolution-friendly LayerNorm that moves channels to last dimensions - before running the normalization and moves them back to original position right after. - """ - - def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs): - super().__init__(normalized_shape, **kwargs) - - def forward(self, x): - x = einops.rearrange(x, "b ... t -> b t ...") - x = super().forward(x) - x = einops.rearrange(x, "b t ... -> b ... t") - return diff --git a/egs/libritts/CODEC/encodec/modules/seanet.py b/egs/libritts/CODEC/encodec/modules/seanet.py deleted file mode 100644 index 76999b298..000000000 --- a/egs/libritts/CODEC/encodec/modules/seanet.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Encodec SEANet-based encoder and decoder implementation.""" - -from typing import Any, Dict, List, Optional - -import numpy as np -import torch.nn as nn -from modules import SLSTM, SConv1d, SConvTranspose1d - - -class SEANetResnetBlock(nn.Module): - """Residual block from SEANet model. - Args: - dim (int): Dimension of the input/output - kernel_sizes (list): List of kernel sizes for the convolutions. - dilations (list): List of dilations for the convolutions. - activation (str): Activation function. 
- activation_params (dict): Parameters to provide to the activation function - norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - compress (int): Reduced dimensionality in residual branches (from Demucs v3) - true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. - """ - - def __init__( - self, - dim: int, - kernel_sizes: List[int] = [3, 1], - dilations: List[int] = [1, 1], - activation: str = "ELU", - activation_params: Dict = {"alpha": 1.0}, - norm: str = "weight_norm", - norm_params: Dict[str, Any] = {}, - causal: bool = False, - pad_mode: str = "reflect", - compress: int = 2, - true_skip: bool = True, - ): - super().__init__() - assert len(kernel_sizes) == len( - dilations - ), "Number of kernel sizes should match number of dilations" - act = getattr(nn, activation) - hidden = dim // compress - block = [] - for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): - in_chs = dim if i == 0 else hidden - out_chs = dim if i == len(kernel_sizes) - 1 else hidden - block += [ - act(**activation_params), - SConv1d( - in_chs, - out_chs, - kernel_size=kernel_size, - dilation=dilation, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ), - ] - self.block = nn.Sequential(*block) - self.shortcut: nn.Module - if true_skip: - self.shortcut = nn.Identity() - else: - self.shortcut = SConv1d( - dim, - dim, - kernel_size=1, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ) - - def forward(self, x): - return self.shortcut(x) + self.block(x) - - -class SEANetEncoder(nn.Module): - """SEANet encoder. - Args: - channels (int): Audio channels. - dimension (int): Intermediate representation dimension. - n_filters (int): Base width for the model. - n_residual_layers (int): nb of residual layers. - ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of - upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here - that must match the decoder order - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function - norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - kernel_size (int): Kernel size for the initial convolution. - last_kernel_size (int): Kernel size for the initial convolution. - residual_kernel_size (int): Kernel size for the residual layers. - dilation_base (int): How much to increase the dilation with each layer. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - true_skip (bool): Whether to use true skip connection or a simple - (streamable) convolution as the skip connection in the residual network blocks. - compress (int): Reduced dimensionality in residual branches (from Demucs v3). - lstm (int): Number of LSTM layers at the end of the encoder. 
- """ - - def __init__( - self, - channels: int = 1, - dimension: int = 128, - n_filters: int = 32, - n_residual_layers: int = 1, - ratios: List[int] = [8, 5, 4, 2], - activation: str = "ELU", - activation_params: dict = {"alpha": 1.0}, - norm: str = "weight_norm", - norm_params: Dict[str, Any] = {}, - kernel_size: int = 7, - last_kernel_size: int = 7, - residual_kernel_size: int = 3, - dilation_base: int = 2, - causal: bool = False, - pad_mode: str = "reflect", - true_skip: bool = False, - compress: int = 2, - lstm: int = 2, - ): - super().__init__() - self.channels = channels - self.dimension = dimension - self.n_filters = n_filters - self.ratios = list(reversed(ratios)) - del ratios - self.n_residual_layers = n_residual_layers - self.hop_length = np.prod(self.ratios) # 计算乘积 - - act = getattr(nn, activation) - mult = 1 - model: List[nn.Module] = [ - SConv1d( - channels, - mult * n_filters, - kernel_size, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ) - ] - # Downsample to raw audio scale - for i, ratio in enumerate(self.ratios): - # Add residual layers - for j in range(n_residual_layers): - model += [ - SEANetResnetBlock( - mult * n_filters, - kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base**j, 1], - norm=norm, - norm_params=norm_params, - activation=activation, - activation_params=activation_params, - causal=causal, - pad_mode=pad_mode, - compress=compress, - true_skip=true_skip, - ) - ] - - # Add downsampling layers - model += [ - act(**activation_params), - SConv1d( - mult * n_filters, - mult * n_filters * 2, - kernel_size=ratio * 2, - stride=ratio, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ), - ] - mult *= 2 - - if lstm: - model += [SLSTM(mult * n_filters, num_layers=lstm)] - - model += [ - act(**activation_params), - SConv1d( - mult * n_filters, - dimension, - last_kernel_size, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ), - ] - - self.model = nn.Sequential(*model) - - def forward(self, x): - return self.model(x) - - -class SEANetDecoder(nn.Module): - """SEANet decoder. - Args: - channels (int): Audio channels. - dimension (int): Intermediate representation dimension. - n_filters (int): Base width for the model. - n_residual_layers (int): nb of residual layers. - ratios (Sequence[int]): kernel size and stride ratios - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function - final_activation (str): Final activation function after all convolutions. - final_activation_params (dict): Parameters to provide to the activation function - norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - kernel_size (int): Kernel size for the initial convolution. - last_kernel_size (int): Kernel size for the initial convolution. - residual_kernel_size (int): Kernel size for the residual layers. - dilation_base (int): How much to increase the dilation with each layer. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - true_skip (bool): Whether to use true skip connection or a simple - (streamable) convolution as the skip connection in the residual network blocks. - compress (int): Reduced dimensionality in residual branches (from Demucs v3). - lstm (int): Number of LSTM layers at the end of the encoder. 
- trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. - If equal to 1.0, it means that all the trimming is done at the right. - """ - - def __init__( - self, - channels: int = 1, - dimension: int = 128, - n_filters: int = 32, - n_residual_layers: int = 1, - ratios: List[int] = [8, 5, 4, 2], - activation: str = "ELU", - activation_params: dict = {"alpha": 1.0}, - final_activation: Optional[str] = None, - final_activation_params: Optional[dict] = None, - norm: str = "weight_norm", - norm_params: Dict[str, Any] = {}, - kernel_size: int = 7, - last_kernel_size: int = 7, - residual_kernel_size: int = 3, - dilation_base: int = 2, - causal: bool = False, - pad_mode: str = "reflect", - true_skip: bool = False, - compress: int = 2, - lstm: int = 2, - trim_right_ratio: float = 1.0, - ): - super().__init__() - self.dimension = dimension - self.channels = channels - self.n_filters = n_filters - self.ratios = ratios - del ratios - self.n_residual_layers = n_residual_layers - self.hop_length = np.prod(self.ratios) - - act = getattr(nn, activation) - mult = int(2 ** len(self.ratios)) - model: List[nn.Module] = [ - SConv1d( - dimension, - mult * n_filters, - kernel_size, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ) - ] - - if lstm: - model += [SLSTM(mult * n_filters, num_layers=lstm)] - - # Upsample to raw audio scale - for i, ratio in enumerate(self.ratios): - # Add upsampling layers - model += [ - act(**activation_params), - SConvTranspose1d( - mult * n_filters, - mult * n_filters // 2, - kernel_size=ratio * 2, - stride=ratio, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - trim_right_ratio=trim_right_ratio, - ), - ] - # Add residual layers - for j in range(n_residual_layers): - model += [ - SEANetResnetBlock( - mult * n_filters // 2, - kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base**j, 1], - activation=activation, - activation_params=activation_params, - norm=norm, - norm_params=norm_params, - causal=causal, - pad_mode=pad_mode, - compress=compress, - true_skip=true_skip, - ) - ] - - mult //= 2 - - # Add final layers - model += [ - act(**activation_params), - SConv1d( - n_filters, - channels, - last_kernel_size, - norm=norm, - norm_kwargs=norm_params, - causal=causal, - pad_mode=pad_mode, - ), - ] - # Add optional final activation to decoder (eg. tanh) - if final_activation is not None: - final_act = getattr(nn, final_activation) - final_activation_params = final_activation_params or {} - model += [final_act(**final_activation_params)] - self.model = nn.Sequential(*model) - - def forward(self, z): - y = self.model(z) - return y - - -def test(): - import torch - - encoder = SEANetEncoder() - decoder = SEANetDecoder() - x = torch.randn(1, 1, 24000) - z = encoder(x) - print("z ", z.shape) - assert 1 == 2 - assert list(z.shape) == [1, 128, 75], z.shape - y = decoder(z) - assert y.shape == x.shape, (x.shape, y.shape) - - -if __name__ == "__main__": - test() diff --git a/egs/libritts/CODEC/encodec/modules/transformer.py b/egs/libritts/CODEC/encodec/modules/transformer.py deleted file mode 100644 index 1768d88f9..000000000 --- a/egs/libritts/CODEC/encodec/modules/transformer.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
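With the defaults above, the encoder's hop length is 8 * 5 * 4 * 2 = 320, so one second of 24 kHz audio becomes 75 latent frames of dimension 128. The `test()` function above aborts early on its leftover `assert 1 == 2`, so here is a standalone version of the intended round-trip check:

    import torch

    encoder = SEANetEncoder()              # hop_length = 8 * 5 * 4 * 2 = 320
    decoder = SEANetDecoder()

    x = torch.randn(1, 1, 24000)           # one second at 24 kHz
    z = encoder(x)
    assert list(z.shape) == [1, 128, 75]   # 24000 / 320 = 75 frames
    y = decoder(z)
    assert y.shape == x.shape              # decoder restores the original length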
-# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""A streamable transformer.""" -import typing as tp -from typing import Any, List, Optional, Union - -import torch -import torch.nn.functional as F -from torch import Tensor, nn - - -def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000): - """Create time embedding for the given positions, target dimension `dim`.""" - # We aim for BTC format - assert dim % 2 == 0 - half_dim = dim // 2 - adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) - phase = positions / (max_period ** (adim / (half_dim - 1))) - return torch.cat( - [ - torch.cos(phase), - torch.sin(phase), - ], - dim=-1, - ) - - -class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): - def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore - if self.norm_first: - sa_input = self.norm1(x) - x = x + self._sa_block(sa_input, x_past, past_context) - x = x + self._ff_block(self.norm2(x)) - else: - sa_input = x - x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) - x = self.norm2(x + self._ff_block(x)) - - return x, sa_input - - # self-attention block - def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore - _, T, _ = x.shape - _, H, _ = x_past.shape - - queries = x - keys = torch.cat([x_past, x], dim=1) - values = keys - - queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) - keys_pos = torch.arange(T + H, device=x.device).view(1, -1) - delta = queries_pos - keys_pos - valid_access = (delta >= 0) & (delta <= past_context) - x = self.self_attn( - queries, keys, values, attn_mask=~valid_access, need_weights=False - )[0] - return self.dropout1(x) - - -class StreamingTransformerEncoder(nn.Module): - """TransformerEncoder with streaming support. - - Args: - dim (int): dimension of the data. - hidden_scale (int): intermediate dimension of FF module is this times the dimension. - num_heads (int): number of heads. - num_layers (int): number of layers. - max_period (float): maxium period of cosines in the positional embedding. - past_context (int or None): receptive field for the causal mask, infinite if None. - gelu (bool): if true uses GeLUs, otherwise use ReLUs. - norm_in (bool): normalize the input. - dropout (float): dropout probability. - **kwargs: See `nn.TransformerEncoderLayer`. 
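The positional encoding used by the streaming transformer is the plain sinusoidal embedding built by `create_sin_embedding` above; a quick shape check (the sizes are arbitrary):

    import torch

    B, T, C = 2, 10, 64                                  # batch, time, channels (C must be even)
    positions = torch.arange(T).view(1, -1, 1).float()   # BTC layout, one position per frame
    pos_emb = create_sin_embedding(positions, dim=C)
    assert pos_emb.shape == (1, T, C)                    # cosine half followed by sine half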
- """ - - def __init__( - self, - dim, - hidden_scale: float = 4.0, - num_heads: int = 8, - num_layers: int = 5, - max_period: float = 10000, - past_context: int = 1000, - gelu: bool = True, - norm_in: bool = True, - dropout: float = 0.0, - **kwargs - ): - super().__init__() - assert dim % num_heads == 0 - hidden_dim = int(dim * hidden_scale) - - self.max_period = max_period - self.past_context = past_context - activation: Any = F.gelu if gelu else F.relu - - self.norm_in: nn.Module - if norm_in: - self.norm_in = nn.LayerNorm(dim) - else: - self.norm_in = nn.Identity() - - self.layers = nn.ModuleList() - for idx in range(num_layers): - self.layers.append( - StreamingTransformerEncoderLayer( - dim, - num_heads, - hidden_dim, - activation=activation, - batch_first=True, - dropout=dropout, - **kwargs - ) - ) - - def forward( - self, - x: Tensor, - states: Optional[List[Tensor]] = None, - offset: Union[int, Tensor] = 0, - ): - B, T, C = x.shape - if states is None: - states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] - - positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset - pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) - - new_state: List[Tensor] = [] - x = self.norm_in(x) - x = x + pos_emb - - for layer_state, layer in zip(states, self.layers): - x, new_layer_state = layer(x, layer_state, self.past_context) - new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) - new_state.append(new_layer_state[:, -self.past_context :, :]) - return x, new_state, offset + T diff --git a/egs/libritts/CODEC/encodec/quantization/__init__.py b/egs/libritts/CODEC/encodec/quantization/__init__.py deleted file mode 100644 index 82d744f5f..000000000 --- a/egs/libritts/CODEC/encodec/quantization/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -# flake8: noqa -from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/egs/libritts/CODEC/encodec/quantization/ac.py b/egs/libritts/CODEC/encodec/quantization/ac.py deleted file mode 100644 index 8d8a770ca..000000000 --- a/egs/libritts/CODEC/encodec/quantization/ac.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Arithmetic coder.""" -import io -import math -import random -from typing import IO, Any, List, Optional - -import torch -from torch import Tensor - -from ..binary import BitPacker, BitUnpacker - - -def build_stable_quantized_cdf( - pdf: Tensor, - total_range_bits: int, - roundoff: float = 1e-8, - min_range: int = 2, - check: bool = True, -) -> Tensor: - """Turn the given PDF into a quantized CDF that splits - [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional - to the PDF. - - Args: - pdf (Tensor): probability distribution, shape should be `[N]`. - total_range_bits (int): see `ArithmeticCoder`, the typical range we expect - during the coding process is `[0, 2 ** total_range_bits - 1]`. - roundoff (float): will round the pdf up to that level to remove difference coming - from e.g. evaluating the Language Model on different architectures. - min_range (int): minimum range width. 
Should always be at least 2 for numerical - stability. Use this to avoid pathological behavior is a value - that is expected to be rare actually happens in real life. - check (bool): if True, checks that nothing bad happened, can be deactivated for speed. - """ - pdf = pdf.detach() - if roundoff: - pdf = (pdf / roundoff).floor() * roundoff - # interpolate with uniform distribution to achieve desired minimum probability. - total_range = 2**total_range_bits - cardinality = len(pdf) - alpha = min_range * cardinality / total_range - assert alpha <= 1, "you must reduce min_range" - ranges = (((1 - alpha) * total_range) * pdf).floor().long() - ranges += min_range - quantized_cdf = torch.cumsum(ranges, dim=-1) - if min_range < 2: - raise ValueError("min_range must be at least 2.") - if check: - assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] - if ( - (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range - ).any() or quantized_cdf[0] < min_range: - raise ValueError("You must increase your total_range_bits.") - return quantized_cdf - - -class ArithmeticCoder: - """ArithmeticCoder, - Let us take a distribution `p` over `N` symbols, and assume we have a stream - of random variables `s_t` sampled from `p`. Let us assume that we have a budget - of `B` bits that we can afford to write on device. There are `2**B` possible numbers, - corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single - sequence `(s_t)` by doing the following: - - 1) Initialize the current range to` [0 ** 2 B - 1]`. - 2) For each time step t, split the current range into contiguous chunks, - one for each possible outcome, with size roughly proportional to `p`. - For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks - would be `{[0, 2], [3, 3]}`. - 3) Select the chunk corresponding to `s_t`, and replace the current range with this. - 4) When done encoding all the values, just select any value remaining in the range. - - You will notice that this procedure can fail: for instance if at any point in time - the range is smaller than `N`, then we can no longer assign a non-empty chunk to each - possible outcome. Intuitively, the more likely a value is, the less the range width - will reduce, and the longer we can go on encoding values. This makes sense: for any efficient - coding scheme, likely outcomes would take less bits, and more of them can be coded - with a fixed budget. - - In practice, we do not know `B` ahead of time, but we have a way to inject new bits - when the current range decreases below a given limit (given by `total_range_bits`), without - having to redo all the computations. If we encode mostly likely values, we will seldom - need to inject new bits, but a single rare value can deplete our stock of entropy! - - In this explanation, we assumed that the distribution `p` was constant. In fact, the present - code works for any sequence `(p_t)` possibly different for each timestep. - We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller - the KL between the true distribution and `p_t`, the most efficient the coding will be. - - Args: - fo (IO[bytes]): file-like object to which the bytes will be written to. - total_range_bits (int): the range `M` described above is `2 ** total_range_bits. - Any time the current range width fall under this limit, new bits will - be injected to rescale the initial range. 
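A small, illustrative call to `build_stable_quantized_cdf` above (the pdf values are arbitrary); every symbol is guaranteed a chunk of at least `min_range`, and the last CDF entry stays within the 24-bit working range:

    import torch

    pdf = torch.tensor([0.75, 0.20, 0.05])
    q_cdf = build_stable_quantized_cdf(pdf, total_range_bits=24)
    assert q_cdf[-1] <= 2 ** 24                                      # fits the coder's range
    assert q_cdf[0] >= 2 and (q_cdf[1:] - q_cdf[:-1] >= 2).all()     # min_range honoured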
- """ - - def __init__(self, fo: IO[bytes], total_range_bits: int = 24): - assert total_range_bits <= 30 - self.total_range_bits = total_range_bits - self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. - self.low: int = 0 - self.high: int = 0 - self.max_bit: int = -1 - self._dbg: List[Any] = [] - self._dbg2: List[Any] = [] - - @property - def delta(self) -> int: - """Return the current range width.""" - return self.high - self.low + 1 - - def _flush_common_prefix(self): - # If self.low and self.high start with the sames bits, - # those won't change anymore as we always just increase the range - # by powers of 2, and we can flush them out to the bit stream. - assert self.high >= self.low, (self.low, self.high) - assert self.high < 2 ** (self.max_bit + 1) - while self.max_bit >= 0: - b1 = self.low >> self.max_bit - b2 = self.high >> self.max_bit - if b1 == b2: - self.low -= b1 << self.max_bit - self.high -= b1 << self.max_bit - assert self.high >= self.low, (self.high, self.low, self.max_bit) - assert self.low >= 0 - self.max_bit -= 1 - self.packer.push(b1) - else: - break - - def push(self, symbol: int, quantized_cdf: Tensor): - """Push the given symbol on the stream, flushing out bits - if possible. - - Args: - symbol (int): symbol to encode with the AC. - quantized_cdf (Tensor): use `build_stable_quantized_cdf` - to build this from your pdf estimate. - """ - while self.delta < 2**self.total_range_bits: - self.low *= 2 - self.high = self.high * 2 + 1 - self.max_bit += 1 - - range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() - range_high = quantized_cdf[symbol].item() - 1 - effective_low = int( - math.ceil(range_low * (self.delta / (2**self.total_range_bits))) - ) - effective_high = int( - math.floor(range_high * (self.delta / (2**self.total_range_bits))) - ) - assert self.low <= self.high - self.high = self.low + effective_high - self.low = self.low + effective_low - assert self.low <= self.high, ( - effective_low, - effective_high, - range_low, - range_high, - ) - self._dbg.append((self.low, self.high)) - self._dbg2.append((self.low, self.high)) - outs = self._flush_common_prefix() - assert self.low <= self.high - assert self.max_bit >= -1 - assert self.max_bit <= 61, self.max_bit - return outs - - def flush(self): - """Flush the remaining information to the stream.""" - while self.max_bit >= 0: - b1 = (self.low >> self.max_bit) & 1 - self.packer.push(b1) - self.max_bit -= 1 - self.packer.flush() - - -class ArithmeticDecoder: - """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. - - Note that this must be called with **exactly** the same parameters and sequence - of quantized cdf as the arithmetic encoder or the wrong values will be decoded. - - If the AC encoder current range is [L, H], with `L` and `H` having the some common - prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. - For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside - `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained - for a specific sequence of symbols and a binary-search allows us to decode those symbols. - At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, - and we will need to read new bits from the stream and repeat the process. 
- - """ - - def __init__(self, fo: IO[bytes], total_range_bits: int = 24): - self.total_range_bits = total_range_bits - self.low: int = 0 - self.high: int = 0 - self.current: int = 0 - self.max_bit: int = -1 - self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. - # Following is for debugging - self._dbg: List[Any] = [] - self._dbg2: List[Any] = [] - self._last: Any = None - - @property - def delta(self) -> int: - return self.high - self.low + 1 - - def _flush_common_prefix(self): - # Given the current range [L, H], if both have a common prefix, - # we know we can remove it from our representation to avoid handling large numbers. - while self.max_bit >= 0: - b1 = self.low >> self.max_bit - b2 = self.high >> self.max_bit - if b1 == b2: - self.low -= b1 << self.max_bit - self.high -= b1 << self.max_bit - self.current -= b1 << self.max_bit - assert self.high >= self.low - assert self.low >= 0 - self.max_bit -= 1 - else: - break - - def pull(self, quantized_cdf: Tensor) -> Optional[int]: - """Pull a symbol, reading as many bits from the stream as required. - This returns `None` when the stream has been exhausted. - - Args: - quantized_cdf (Tensor): use `build_stable_quantized_cdf` - to build this from your pdf estimate. This must be **exatly** - the same cdf as the one used at encoding time. - """ - while self.delta < 2**self.total_range_bits: - bit = self.unpacker.pull() - if bit is None: - return None - self.low *= 2 - self.high = self.high * 2 + 1 - self.current = self.current * 2 + bit - self.max_bit += 1 - - def bin_search(low_idx: int, high_idx: int): - # Binary search is not just for coding interviews :) - if high_idx < low_idx: - raise RuntimeError("Binary search failed") - mid = (low_idx + high_idx) // 2 - range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 - range_high = quantized_cdf[mid].item() - 1 - effective_low = int( - math.ceil(range_low * (self.delta / (2**self.total_range_bits))) - ) - effective_high = int( - math.floor(range_high * (self.delta / (2**self.total_range_bits))) - ) - low = effective_low + self.low - high = effective_high + self.low - if self.current >= low: - if self.current <= high: - return (mid, low, high, self.current) - else: - return bin_search(mid + 1, high_idx) - else: - return bin_search(low_idx, mid - 1) - - self._last = (self.low, self.high, self.current, self.max_bit) - sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) - self._dbg.append((self.low, self.high, self.current)) - self._flush_common_prefix() - self._dbg2.append((self.low, self.high, self.current)) - - return sym - - -def test(): - torch.manual_seed(1234) - random.seed(1234) - for _ in range(4): - pdfs = [] - cardinality = random.randrange(4000) - steps = random.randrange(100, 500) - fo = io.BytesIO() - encoder = ArithmeticCoder(fo) - symbols = [] - for step in range(steps): - pdf = torch.softmax(torch.randn(cardinality), dim=0) - pdfs.append(pdf) - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - symbol = torch.multinomial(pdf, 1).item() - symbols.append(symbol) - encoder.push(symbol, q_cdf) - encoder.flush() - - fo.seek(0) - decoder = ArithmeticDecoder(fo) - for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - decoded_symbol = decoder.pull(q_cdf) - assert decoded_symbol == symbol, idx - assert decoder.pull(torch.zeros(1)) is None - - -if __name__ == "__main__": - test() diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py 
b/egs/libritts/CODEC/encodec/quantization/core_vq.py deleted file mode 100644 index 4719e20f7..000000000 --- a/egs/libritts/CODEC/encodec/quantization/core_vq.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# This implementation is inspired from -# https://github.com/lucidrains/vector-quantize-pytorch -# which is released under MIT License. Hereafter, the original license: -# MIT License -# -# Copyright (c) 2020 Phil Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -"""Core vector quantization implementation.""" - -from typing import Any, Callable, Optional, Union - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import nn - -from .distrib import broadcast_tensors - - -def default(val: Any, d: Any) -> Any: - return val if val is not None else d - - -def ema_inplace(moving_avg, new, decay: float): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -def kmeans(samples, num_clusters: int, num_iters: int = 10): - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") - dists = -(diffs**2).sum(dim=-1) - - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - Args: - dim (int): Dimension. 
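The k-means initialisation helper above can be exercised on its own; a tiny sketch with two well-separated synthetic clusters:

    import torch

    samples = torch.cat([
        torch.randn(100, 8) + 5.0,   # cluster around +5
        torch.randn(100, 8) - 5.0,   # cluster around -5
    ])
    means, bins = kmeans(samples, num_clusters=2, num_iters=10)
    assert means.shape == (2, 8)
    assert int(bins.sum()) == 200    # every sample is assigned to a centroid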
- codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.99, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.decay = decay - init_fn: Union[Callable[..., torch.Tensor], Any] = ( - uniform_init if not kmeans_init else torch.zeros - ) - embed = init_fn(codebook_size, dim) - - self.codebook_size = codebook_size - - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.register_buffer("inited", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(codebook_size)) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - - @torch.jit.ignore - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - broadcast_tensors(self.buffers()) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - broadcast_tensors(self.buffers()) - - def preprocess(self, x): - x = rearrange(x, "... d -> (...) 
d") - return x - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def postprocess_emb(self, embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = self.preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = self.postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = self.preprocess(x) - - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = self.postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - if self.training: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. - self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. - Currently supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. 
- """ - - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: Optional[int] = None, - decay: float = 0.99, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - commitment_weight: float = 1.0, - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = ( - nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() - ) - self.project_out = ( - nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() - ) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self._codebook = EuclideanCodebook( - dim=_codebook_dim, - codebook_size=codebook_size, - kmeans_init=kmeans_init, - kmeans_iters=kmeans_iters, - decay=decay, - epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code, - ) - self.codebook_size = codebook_size - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): - quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class ResidualVectorQuantization(nn.Module): - """Residual vector quantization implementation. - Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf - """ - - def __init__(self, *, num_quantizers, **kwargs): - super().__init__() - self.layers = nn.ModuleList( - [VectorQuantization(**kwargs) for _ in range(num_quantizers)] - ) - - def forward(self, x, n_q: Optional[int] = None): - quantized_out = 0.0 - residual = x - - all_losses = [] - all_indices = [] - - n_q = n_q or len(self.layers) - - for layer in self.layers[:n_q]: - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - - out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) - return quantized_out, out_indices, out_losses - - def encode( - self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None - ) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - st = st or 0 - for layer in self.layers[st:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out diff --git a/egs/libritts/CODEC/encodec/quantization/distrib.py b/egs/libritts/CODEC/encodec/quantization/distrib.py deleted file mode 100644 index 41ac7525f..000000000 --- a/egs/libritts/CODEC/encodec/quantization/distrib.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Torch distributed utilities.""" -from typing import Dict, Iterable, List - -import torch -from torch import distributed as dist - - -def rank(): - if dist.is_initialized(): - return dist.get_rank() - else: - return 0 - - -def world_size(): - if dist.is_initialized(): - return dist.get_world_size() - else: - return 1 - - -def is_distributed(): - return world_size() > 1 - - -def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM): - if is_distributed(): - return dist.all_reduce(tensor, op) - - -def _is_complex_or_float(tensor): - return torch.is_floating_point(tensor) or torch.is_complex(tensor) - - -def _check_number_of_params(params: List[torch.Tensor]): - # utility function to check that the number of params in all workers is the same, - # and thus avoid a deadlock with distributed all reduce. - if not is_distributed() or not params: - return - # print('params[0].device ', params[0].device) - tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) - all_reduce(tensor) - if tensor.item() != len(params) * world_size(): - # If not all the workers have the same number, for at least one of them, - # this inequality will be verified. - raise RuntimeError( - f"Mismatch in number of params: ours is {len(params)}, " - "at least one worker has a different one." - ) - - -def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0): - """Broadcast the tensors from the given parameters to all workers. - This can be used to ensure that all workers have the same model to start with. 
- """ - if not is_distributed(): - return - tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] - _check_number_of_params(tensors) - handles = [] - for tensor in tensors: - # src = int(rank()) # added code - handle = dist.broadcast(tensor.data, src=src, async_op=True) - handles.append(handle) - for handle in handles: - handle.wait() - - -def sync_buffer(buffers, average=True): - """ - Sync grad for buffers. If average is False, broadcast instead of averaging. - """ - if not is_distributed(): - return - handles = [] - for buffer in buffers: - if torch.is_floating_point(buffer.data): - if average: - handle = dist.all_reduce( - buffer.data, op=dist.ReduceOp.SUM, async_op=True - ) - else: - handle = dist.broadcast(buffer.data, src=0, async_op=True) - handles.append((buffer, handle)) - for buffer, handle in handles: - handle.wait() - if average: - buffer.data /= world_size - - -def sync_grad(params): - """ - Simpler alternative to DistributedDataParallel, that doesn't rely - on any black magic. For simple models it can also be as fast. - Just call this on your model parameters after the call to backward! - """ - if not is_distributed(): - return - handles = [] - for p in params: - if p.grad is not None: - handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True) - handles.append((p, handle)) - for p, handle in handles: - handle.wait() - p.grad.data /= world_size() - - -def average_metrics(metrics: Dict[str, float], count=1.0): - """Average a dictionary of metrics across all workers, using the optional - `count` as unormalized weight. - """ - if not is_distributed(): - return metrics - keys, values = zip(*metrics.items()) - device = "cuda" if torch.cuda.is_available() else "cpu" - tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) - tensor *= count - all_reduce(tensor) - averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() - return dict(zip(keys, averaged)) diff --git a/egs/libritts/CODEC/encodec/quantization/vq.py b/egs/libritts/CODEC/encodec/quantization/vq.py deleted file mode 100644 index 8e59887a6..000000000 --- a/egs/libritts/CODEC/encodec/quantization/vq.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE -"""Residual vector quantizer implementation.""" -import math -from dataclasses import dataclass, field -from typing import Optional - -import torch -from torch import Tensor, nn - -from .core_vq import ResidualVectorQuantization - - -@dataclass -class QuantizedResult: - quantized: Tensor - codes: Tensor - bandwidth: Tensor # bandwidth in kb/s used, per batch item. - penalty: Optional[Tensor] = None - metrics: dict = field(default_factory=dict) - - -class ResidualVectorQuantizer(nn.Module): - """Residual Vector Quantizer. - Args: - dimension (int): Dimension of the codebooks. - n_q (int): Number of residual vector quantizers used. - bins (int): Codebook size. - decay (float): Decay for exponential moving average over the codebooks. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. 
- """ - - def __init__( - self, - dimension: int = 256, - n_q: int = 8, - bins: int = 1024, - decay: float = 0.99, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.n_q = n_q - self.dimension = dimension - self.bins = bins - self.decay = decay - self.kmeans_init = kmeans_init - self.kmeans_iters = kmeans_iters - self.threshold_ema_dead_code = threshold_ema_dead_code - self.vq = ResidualVectorQuantization( - dim=self.dimension, - codebook_size=self.bins, - num_quantizers=self.n_q, - decay=self.decay, - kmeans_init=self.kmeans_init, - kmeans_iters=self.kmeans_iters, - threshold_ema_dead_code=self.threshold_ema_dead_code, - ) - - def forward( - self, x: Tensor, sample_rate: int, bandwidth: Optional[float] = None - ) -> QuantizedResult: - """Residual vector quantization on the given input tensor. - Args: - x (Tensor): Input tensor. - sample_rate (int): Sample rate of the input tensor. - bandwidth (float): Target bandwidth. - Returns: - QuantizedResult: - The quantized (or approximately quantized) representation with - the associated bandwidth and any penalty term for the loss. - """ - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - quantized, codes, commit_loss = self.vq(x, n_q=n_q) - bw = torch.tensor(n_q * bw_per_q).to(x) - return quantized, codes, bw, torch.mean(commit_loss) - # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) - - def get_num_quantizers_for_bandwidth( - self, sample_rate: int, bandwidth: Optional[float] = None - ) -> int: - """Return n_q based on specified target bandwidth.""" - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.n_q - if bandwidth and bandwidth > 0.0: - n_q = int(max(1, math.floor(bandwidth / bw_per_q))) - return n_q - - def get_bandwidth_per_quantizer(self, sample_rate: int): - """Return bandwidth per quantizer for a given input sample rate.""" - return math.log2(self.bins) * sample_rate / 1000 - - def encode( - self, - x: Tensor, - sample_rate: int, - bandwidth: Optional[float] = None, - st: Optional[int] = None, - ) -> Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth. - The RVQ encode method sets the appropriate number of quantizer to use - and returns indices for each quantizer. - """ - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - st = st or 0 - codes = self.vq.encode(x, n_q=n_q, st=st) - return codes - - def decode(self, codes: Tensor) -> Tensor: - """Decode the given codes to the quantized representation.""" - quantized = self.vq.decode(codes) - return quantized diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py deleted file mode 100644 index 00ef9882a..000000000 --- a/egs/libritts/CODEC/encodec/scheduler.py +++ /dev/null @@ -1,171 +0,0 @@ -# original implementation is from https://github.com/ZhikangNiu/encodec-pytorch/blob/main/scheduler.py - -# Copyright 2024 Zhi-Kang Niu -# MIT License - -import math -from bisect import bisect_right - -from torch.optim.lr_scheduler import _LRScheduler - - -# It will be replaced with huggingface optimization -class WarmUpLR(_LRScheduler): - """warmup_training learning rate scheduler - Args: - optimizer: optimzier(e.g. 
SGD) - total_iters: totoal_iters of warmup phase - """ - - def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1): - - self.total_iters = iter_per_epoch * warmup_epoch - self.iter_per_epoch = iter_per_epoch - super().__init__(optimizer, last_epoch) - - def get_lr(self): - """we will use the first m batches, and set the learning - rate to base_lr * m / total_iters - """ - return [ - base_lr * self.last_epoch / (self.total_iters + 1e-8) - for base_lr in self.base_lrs - ] - - -class WarmupLrScheduler(_LRScheduler): - def __init__( - self, - optimizer, - warmup_iter=500, - warmup_ratio=5e-4, - warmup="exp", - last_epoch=-1, - ): - self.warmup_iter = warmup_iter - self.warmup_ratio = warmup_ratio - self.warmup = warmup - super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) - - def get_lr(self): - ratio = self.get_lr_ratio() - lrs = [ratio * lr for lr in self.base_lrs] - return lrs - - def get_lr_ratio(self): - if self.last_epoch < self.warmup_iter: - ratio = self.get_warmup_ratio() - else: - ratio = self.get_main_ratio() - return ratio - - def get_main_ratio(self): - raise NotImplementedError - - def get_warmup_ratio(self): - assert self.warmup in ("linear", "exp") - alpha = self.last_epoch / self.warmup_iter - if self.warmup == "linear": - ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha - elif self.warmup == "exp": - ratio = self.warmup_ratio ** (1.0 - alpha) - return ratio - - -class WarmupPolyLrScheduler(WarmupLrScheduler): - def __init__( - self, - optimizer, - power, - max_iter, - warmup_iter=500, - warmup_ratio=5e-4, - warmup="exp", - last_epoch=-1, - ): - self.power = power - self.max_iter = max_iter - super(WarmupPolyLrScheduler, self).__init__( - optimizer, warmup_iter, warmup_ratio, warmup, last_epoch - ) - - def get_main_ratio(self): - real_iter = self.last_epoch - self.warmup_iter - real_max_iter = self.max_iter - self.warmup_iter - alpha = real_iter / real_max_iter - ratio = (1 - alpha) ** self.power - return ratio - - -class WarmupExpLrScheduler(WarmupLrScheduler): - def __init__( - self, - optimizer, - gamma, - interval=1, - warmup_iter=500, - warmup_ratio=5e-4, - warmup="exp", - last_epoch=-1, - ): - self.gamma = gamma - self.interval = interval - super(WarmupExpLrScheduler, self).__init__( - optimizer, warmup_iter, warmup_ratio, warmup, last_epoch - ) - - def get_main_ratio(self): - real_iter = self.last_epoch - self.warmup_iter - ratio = self.gamma ** (real_iter // self.interval) - return ratio - - -class WarmupCosineLrScheduler(WarmupLrScheduler): - def __init__( - self, - optimizer, - max_iter, - eta_ratio=0, - warmup_iter=500, - warmup_ratio=5e-4, - warmup="exp", - last_epoch=-1, - ): - self.eta_ratio = eta_ratio - self.max_iter = max_iter - super(WarmupCosineLrScheduler, self).__init__( - optimizer, warmup_iter, warmup_ratio, warmup, last_epoch - ) - - def get_main_ratio(self): - real_iter = self.last_epoch - self.warmup_iter - real_max_iter = self.max_iter - self.warmup_iter - return ( - self.eta_ratio - + (1 - self.eta_ratio) - * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) - / 2 - ) - - -class WarmupStepLrScheduler(WarmupLrScheduler): - def __init__( - self, - optimizer, - milestones: list, - gamma=0.1, - warmup_iter=500, - warmup_ratio=5e-4, - warmup="exp", - last_epoch=-1, - ): - self.milestones = milestones - self.gamma = gamma - super(WarmupStepLrScheduler, self).__init__( - optimizer, warmup_iter, warmup_ratio, warmup, last_epoch - ) - - def get_main_ratio(self): - real_iter = self.last_epoch - self.warmup_iter 
- ratio = self.gamma ** bisect_right(self.milestones, real_iter) - return ratio diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py deleted file mode 100755 index a4f2eb7ab..000000000 --- a/egs/libritts/CODEC/encodec/train.py +++ /dev/null @@ -1,1188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (Author: Zengwei Yao) -# 2024 The Chinese University of HK (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import itertools -import logging -import math -import random -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from codec_datamodule import LibriTTSCodecDataModule -from encodec import Encodec -from lhotse.utils import fix_random_seed -from scheduler import WarmupCosineLrScheduler -from torch import nn -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from utils import MetricsTracker, save_checkpoint - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-samples", - type=int, - default=3, - help="Number of samples to generate for tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=500, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="encodec/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lr", type=float, default=3.0e-4, help="The base learning rate." 
- ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=5, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - """ - params = AttributeDict( - { - # training params - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 200, - "env_info": get_env_info(), - "sampling_rate": 24000, - "audio_normalization": False, - "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 0.1, # loss scaling coefficient for waveform loss - "lambda_feat": 4.0, # loss scaling coefficient for feat loss - "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 1000.0, # loss scaling coefficient for commitment loss - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. 
- """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - """Get the model based on the configuration.""" - - from discriminators import ( - MultiPeriodDiscriminator, - MultiScaleDiscriminator, - MultiScaleSTFTDiscriminator, - ) - from modules.seanet import SEANetDecoder, SEANetEncoder - from quantization import ResidualVectorQuantizer - - # generator_params = { - # "generator_n_filters": 32, - # "dimension": 512, - # "ratios": [2, 2, 2, 4], - # "target_bandwidths": [7.5, 15], - # "bins": 1024, - # } - # discriminator_params = { - # "stft_discriminator_n_filters": 32, - # "discriminator_epoch_start": 5, - # } - # inference_params = { - # "target_bw": 7.5, - # } - - generator_params = { - "generator_n_filters": 32, - "dimension": 512, - "ratios": [8, 5, 4, 2], - "target_bandwidths": [1.5, 3, 6, 12, 24], - "bins": 1024, - } - discriminator_params = { - "stft_discriminator_n_filters": 32, - "discriminator_epoch_start": 5, - "n_ffts": [1024, 2048, 512], - "hop_lengths": [256, 512, 128], - "win_lengths": [1024, 2048, 512], - } - inference_params = { - "target_bw": 6, - } - - params.update(generator_params) - params.update(discriminator_params) - params.update(inference_params) - - hop_length = np.prod(params.ratios) - n_q = int( - 1000 - * params.target_bandwidths[-1] - // (math.ceil(params.sampling_rate / hop_length) * 10) - ) - - encoder = SEANetEncoder( - n_filters=params.generator_n_filters, - dimension=params.dimension, - ratios=params.ratios, - ) - decoder = SEANetDecoder( - n_filters=params.generator_n_filters, - dimension=params.dimension, - ratios=params.ratios, - ) - quantizer = ResidualVectorQuantizer( - dimension=params.dimension, n_q=n_q, bins=params.bins - ) - - model = Encodec( - params=params, - sampling_rate=params.sampling_rate, - target_bandwidths=params.target_bandwidths, - encoder=encoder, - quantizer=quantizer, - decoder=decoder, - multi_scale_discriminator=None, - multi_period_discriminator=None, - multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( - n_filters=params.stft_discriminator_n_filters, - n_ffts=params.n_ffts, - hop_lengths=params.hop_lengths, - win_lengths=params.win_lengths, - ), - ) - return model - - -def prepare_input( - params: AttributeDict, - batch: dict, - device: torch.device, - is_training: bool = True, -): - """Parse batch data""" - audio = batch["audio"].to(device, memory_format=torch.contiguous_format) - features = batch["features"].to(device, memory_format=torch.contiguous_format) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - - if is_training: - audio_dims = audio.size(-1) - start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate)) - audio = audio[:, start_idx : params.sampling_rate + start_idx] - else: - # NOTE(zengrui): a very coarse setup - audio = audio[ - :, params.sampling_rate : params.sampling_rate + params.sampling_rate - ] - - if params.audio_normalization: - mean = audio.mean(dim=-1, keepdim=True) - std = audio.std(dim=-1, keepdim=True) - audio = (audio - mean) / (std + 1e-7) - - return audio, audio_lens, features, features_lens - - 
-def train_discriminator(weight, global_step, threshold=0, value=0.0): - if global_step < threshold: - weight = value - return weight - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model to be trained. - optimizer_g: - The optimizer for generator. - optimizer_d: - The optimizer for discriminator. - scheduler_g: - The learning rate scheduler for generator, we call step() every epoch. - scheduler_d: - The learning rate scheduler for discriminator, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["audio"]) - ( - audio, - audio_lens, - _, - _, - ) = prepare_input(params, batch, device) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - d_weight = train_discriminator( - params.lambda_adv, - params.cur_epoch, - threshold=params.discriminator_epoch_start, - ) - # forward discriminator - ( - disc_stft_real_adv_loss, - disc_stft_fake_adv_loss, - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - stats_d, - ) = model( - speech=audio, - speech_lengths=audio_lens, - return_sample=False, - forward_generator=False, - ) - disc_loss = ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) * d_weight - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(disc_loss).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - g_weight = train_discriminator( - params.lambda_adv, - params.cur_epoch, - threshold=params.discriminator_epoch_start, - ) - # forward generator - ( - commit_loss, - gen_stft_adv_loss, - gen_period_adv_loss, - gen_scale_adv_loss, - feature_stft_loss, - feature_period_loss, - feature_scale_loss, - wav_reconstruction_loss, - 
mel_reconstruction_loss, - stats_g, - ) = model( - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - gen_adv_loss = ( - gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss - ) * g_weight - feature_loss = ( - feature_stft_loss + feature_period_loss + feature_scale_loss - ) - reconstruction_loss = ( - params.lambda_wav * wav_reconstruction_loss - + params.lambda_rec * mel_reconstruction_loss - ) - gen_loss = ( - gen_adv_loss - + reconstruction_loss - + params.lambda_feat * feature_loss - + params.lambda_com * commit_loss - ) - loss_info["generator_loss"] = gen_loss - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(gen_loss).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - # step per iteration - scheduler_g.step() - scheduler_d.step() - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - speech_hat_, speech_, _, _ = stats_g["returned_sample"] - - speech_hat_i = speech_hat_[0] - speech_i = speech_[0] - if speech_hat_i.dim() > 1: - speech_hat_i = speech_hat_i.squeeze(0) - speech_i = speech_i.squeeze(0) - tb_writer.add_audio( - f"train/speech_hat_", - speech_hat_i, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - f"train/speech_", - speech_i, - params.batch_idx_train, - params.sampling_rate, - ) - # tb_writer.add_image( - # "train/mel_hat_", - # plot_feature(mel_hat_), - # params.batch_idx_train, 
- # dataformats="HWC", - # ) - # tb_writer.add_image( - # "train/mel_", - # plot_feature(mel_), - # params.batch_idx_train, - # dataformats="HWC", - # ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - rank=rank, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None and rank == 0: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - for index in range(params.num_samples): # 3 - speech_hat_i = speech_hat[index] - speech_i = speech[index] - if speech_hat_i.dim() > 1: - speech_hat_i = speech_hat_i.squeeze(0) - speech_i = speech_i.squeeze(0) - tb_writer.add_audio( - f"train/valid_speech_hat_{index}", - speech_hat_i, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - f"train/valid_speech_{index}", - speech_i, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - returned_sample = (None, None) - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["audio"]) - ( - audio, - audio_lens, - _, - _, - ) = prepare_input(params, batch, device, is_training=False) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - d_weight = train_discriminator( - params.lambda_adv, - params.cur_epoch, - threshold=params.discriminator_epoch_start, - ) - - # forward discriminator - ( - disc_stft_real_adv_loss, - disc_stft_fake_adv_loss, - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - stats_d, - ) = model( - speech=audio, - speech_lengths=audio_lens, - return_sample=False, - forward_generator=False, - ) - disc_loss = ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) * d_weight - assert disc_loss.requires_grad is False - loss_info["discriminator_loss"] = disc_loss - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - g_weight = train_discriminator( - params.lambda_adv, - params.cur_epoch, - threshold=params.discriminator_epoch_start, - ) - # forward generator - ( - commit_loss, - gen_stft_adv_loss, - gen_period_adv_loss, - gen_scale_adv_loss, - feature_stft_loss, - feature_period_loss, - feature_scale_loss, - wav_reconstruction_loss, - mel_reconstruction_loss, - stats_g, - ) = model( - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=False, - ) - gen_adv_loss = ( - gen_stft_adv_loss + 
gen_period_adv_loss + gen_scale_adv_loss - ) * g_weight - feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss - reconstruction_loss = ( - params.lambda_wav * wav_reconstruction_loss - + params.lambda_rec * mel_reconstruction_loss - ) - gen_loss = ( - gen_adv_loss - + reconstruction_loss - + params.lambda_feat * feature_loss - + params.lambda_com * commit_loss - ) - assert gen_loss.requires_grad is False - loss_info["generator_loss"] = gen_loss - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # infer for first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - _, audio_hat = inner_model.inference( - x=audio, target_bw=params.target_bw - ) - returned_sample = (audio_hat, audio) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - ( - audio, - audio_lens, - _, - _, - ) = prepare_input(params, batch, device) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - ( - disc_stft_real_adv_loss, - disc_stft_fake_adv_loss, - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - stats_d, - ) = model( - speech=audio, - speech_lengths=audio_lens, - return_sample=False, - forward_generator=False, - ) - loss_d = ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) * train_discriminator( - params.lambda_adv, - params.cur_epoch, - threshold=params.discriminator_train_start, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - ( - commit_loss, - gen_stft_adv_loss, - gen_period_adv_loss, - gen_scale_adv_loss, - feature_stft_loss, - feature_period_loss, - feature_scale_loss, - wav_reconstruction_loss, - mel_reconstruction_loss, - stats_g, - ) = model( - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=False, - ) - loss_g = ( - (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * train_discriminator( - params.lambda_adv, - 0, - threshold=params.discriminator_epoch_start, - ) - + ( - params.lambda_wav * wav_reconstruction_loss - + params.lambda_rec * mel_reconstruction_loss - ) - + params.lambda_feat - * (feature_stft_loss + feature_period_loss + feature_scale_loss) - + params.lambda_com * commit_loss - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. 
We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - libritts = LibriTTSCodecDataModule(args) - - if params.full_libri: - train_cuts = libritts.train_all_shuf_cuts() - else: - train_cuts = libritts.train_clean_100_cuts() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - encoder = model.encoder - decoder = model.decoder - quantizer = model.quantizer - multi_scale_discriminator = model.multi_scale_discriminator - multi_period_discriminator = model.multi_period_discriminator - multi_scale_stft_discriminator = model.multi_scale_stft_discriminator - - num_param_e = sum([p.numel() for p in encoder.parameters()]) - logging.info(f"Number of parameters in encoder: {num_param_e}") - num_param_d = sum([p.numel() for p in decoder.parameters()]) - logging.info(f"Number of parameters in decoder: {num_param_d}") - num_param_q = sum([p.numel() for p in quantizer.parameters()]) - logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = ( - sum([p.numel() for p in multi_scale_discriminator.parameters()]) - if multi_scale_discriminator is not None - else 0 - ) - logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = ( - sum([p.numel() for p in multi_period_discriminator.parameters()]) - if multi_period_discriminator is not None - else 0 - ) - logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") - num_param_dstft = sum( - [p.numel() for p in multi_scale_stft_discriminator.parameters()] - ) - logging.info( - f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" - ) - logging.info( - f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" - ) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = DDP( - model, - device_ids=[rank], - find_unused_parameters=True, - ) - - optimizer_g = torch.optim.AdamW( - itertools.chain( - encoder.parameters(), - quantizer.parameters(), - decoder.parameters(), - ), - lr=params.lr, - betas=(0.5, 0.9), - ) - discriminator_params = [ - multi_scale_stft_discriminator.parameters(), - ] - if multi_scale_discriminator is not None: - discriminator_params.append(multi_scale_discriminator.parameters()) - 
if multi_period_discriminator is not None: - discriminator_params.append(multi_period_discriminator.parameters()) - optimizer_d = torch.optim.AdamW( - itertools.chain(*discriminator_params), - lr=params.lr, - betas=(0.5, 0.9), - ) - - scheduler_g = WarmupCosineLrScheduler( - optimizer=optimizer_g, - max_iter=params.num_epochs * 1500, - eta_ratio=0.1, - warmup_iter=params.discriminator_epoch_start * 1500, - warmup_ratio=1e-4, - ) - scheduler_d = WarmupCosineLrScheduler( - optimizer=optimizer_d, - max_iter=params.num_epochs * 1500, - eta_ratio=0.1, - warmup_iter=params.discriminator_epoch_start * 1500, - warmup_ratio=1e-4, - ) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - train_dl = libritts.train_dataloaders( - train_cuts, - world_size=world_size, - rank=rank, - ) - - valid_cuts = libritts.dev_clean_cuts() - valid_dl = libritts.valid_dataloaders( - valid_cuts, - world_size=world_size, - rank=rank, - ) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - 
copyfile(src=filename, dst=best_valid_filename) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - LibriTTSCodecDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libritts/CODEC/encodec/utils.py b/egs/libritts/CODEC/encodec/utils.py deleted file mode 120000 index 7c9586776..000000000 --- a/egs/libritts/CODEC/encodec/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../vctk/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/libritts/CODEC/local/compute_spectrogram_libritts.py b/egs/libritts/CODEC/local/compute_spectrogram_libritts.py deleted file mode 100755 index 8d864db92..000000000 --- a/egs/libritts/CODEC/local/compute_spectrogram_libritts.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao,) -# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the VCTK dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/spectrogram. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import torch -from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig -from lhotse.audio import RecordingSet -from lhotse.recipes.utils import read_manifests_if_cached -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. 
If None, we will use all""", - ) - parser.add_argument( - "--sampling-rate", - type=int, - default=24000, - help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", - ) - - return parser.parse_args() - - -def compute_spectrogram_libritts( - dataset: Optional[str] = None, sampling_rate: int = 24000 -): - src_dir = Path("data/manifests") - output_dir = Path("data/spectrogram") - num_jobs = min(32, os.cpu_count()) - - frame_length = 1024 / sampling_rate # (in second) - frame_shift = 256 / sampling_rate # (in second) - use_fft_mag = True - - prefix = "libritts" - suffix = "jsonl.gz" - if dataset is None: - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) - else: - dataset_parts = dataset.split(" ", -1) - - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=frame_length, - frame_shift=frame_shift, - use_fft_mag=use_fft_mag, - ) - extractor = Spectrogram(config) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if sampling_rate != 24000: - logging.info(f"Resampling audio to {sampling_rate}") - cut_set = cut_set.resample(sampling_rate) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_spectrogram_libritts() diff --git a/egs/libritts/CODEC/local/display_manifest_statistics.py b/egs/libritts/CODEC/local/display_manifest_statistics.py deleted file mode 100755 index ec00e0454..000000000 --- a/egs/libritts/CODEC/local/display_manifest_statistics.py +++ /dev/null @@ -1,341 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
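Aside on the spectrogram settings in compute_spectrogram_libritts() above: per the inline comments there, lhotse's SpectrogramConfig takes frame_length and frame_shift in seconds, so the 1024-sample window and 256-sample hop are divided by the sampling rate. A minimal sketch of what that works out to at the default 24 kHz (illustrative only):

sampling_rate = 24000
frame_length = 1024 / sampling_rate  # ~0.0427 s, i.e. a 1024-sample analysis window
frame_shift = 256 / sampling_rate    # ~0.0107 s, i.e. a 256-sample hop
print(f"window {frame_length * 1000:.1f} ms, hop {frame_shift * 1000:.1f} ms")  # window 42.7 ms, hop 10.7 ms

If a different --sampling-rate is passed, the window and hop stay at 1024 and 256 samples, so their durations in seconds scale accordingly.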
- -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz", - "./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz", - "./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz", - "./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz", - "./data/spectrogram/libritts_cuts_dev-other.jsonl.gz", - "./data/spectrogram/libritts_cuts_test-clean.jsonl.gz", - "./data/spectrogram/libritts_cuts_test-other.jsonl.gz", - ] - for path in paths: - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 33236 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 53:47:18 _ -________________________________________ -_ mean _ 5.8 _ -________________________________________ -_ std _ 4.6 _ -________________________________________ -_ min _ 0.2 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.5 _ -________________________________________ -_ 75% _ 7.9 _ -________________________________________ -_ 99% _ 21.4 _ -________________________________________ -_ 99.5% _ 23.7 _ -________________________________________ -_ 99.9% _ 27.8 _ -________________________________________ -_ max _ 33.2 _ -________________________________________ -_ Recordings available: _ 33236 _ -________________________________________ -_ Features available: _ 33236 _ -________________________________________ -_ Supervisions available: _ 33236 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 53:47:18 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz statistics: -_________________________________________ -_ Cuts count: _ 116500 _ -_________________________________________ -_ Total duration (hh:mm:ss) _ 191:17:42 _ -_________________________________________ -_ mean _ 5.9 _ -_________________________________________ -_ std _ 4.6 _ -_________________________________________ -_ min _ 0.1 _ -_________________________________________ -_ 25% _ 2.4 _ -_________________________________________ -_ 50% _ 4.6 _ -_________________________________________ -_ 75% _ 8.1 _ -_________________________________________ -_ 99% _ 21.3 _ -_________________________________________ -_ 99.5% _ 23.4 _ -_________________________________________ -_ 99.9% _ 27.4 _ -_________________________________________ -_ max _ 40.4 _ -_________________________________________ -_ Recordings available: _ 116500 _ -_________________________________________ -_ Features available: _ 116500 _ -_________________________________________ -_ Supervisions available: _ 116500 _ -_________________________________________ -SUPERVISION custom fields: -Speech 
duration statistics: -___________________________________________________________________ -_ Total speech duration _ 191:17:42 _ 100.00% of recording _ -___________________________________________________________________ -_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _ -___________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -___________________________________________________________________ - -./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz statistics: -_________________________________________ -_ Cuts count: _ 205043 _ -_________________________________________ -_ Total duration (hh:mm:ss) _ 310:04:36 _ -_________________________________________ -_ mean _ 5.4 _ -_________________________________________ -_ std _ 4.4 _ -_________________________________________ -_ min _ 0.1 _ -_________________________________________ -_ 25% _ 2.3 _ -_________________________________________ -_ 50% _ 4.2 _ -_________________________________________ -_ 75% _ 7.3 _ -_________________________________________ -_ 99% _ 20.6 _ -_________________________________________ -_ 99.5% _ 22.8 _ -_________________________________________ -_ 99.9% _ 27.4 _ -_________________________________________ -_ max _ 43.9 _ -_________________________________________ -_ Recordings available: _ 205043 _ -_________________________________________ -_ Features available: _ 205043 _ -_________________________________________ -_ Supervisions available: _ 205043 _ -_________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -___________________________________________________________________ -_ Total speech duration _ 310:04:36 _ 100.00% of recording _ -___________________________________________________________________ -_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ -___________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -___________________________________________________________________ - -./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 5736 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 08:58:13 _ -________________________________________ -_ mean _ 5.6 _ -________________________________________ -_ std _ 4.3 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.4 _ -________________________________________ -_ 75% _ 7.8 _ -________________________________________ -_ 99% _ 19.9 _ -________________________________________ -_ 99.5% _ 21.9 _ -________________________________________ -_ 99.9% _ 26.3 _ -________________________________________ -_ max _ 30.1 _ -________________________________________ -_ Recordings available: _ 5736 _ -________________________________________ -_ Features available: _ 5736 _ -________________________________________ -_ Supervisions available: _ 5736 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 08:58:13 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ 
-__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/spectrogram/libritts_cuts_dev-other.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 4613 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 06:25:52 _ -________________________________________ -_ mean _ 5.0 _ -________________________________________ -_ std _ 4.1 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.2 _ -________________________________________ -_ 50% _ 3.8 _ -________________________________________ -_ 75% _ 6.5 _ -________________________________________ -_ 99% _ 19.7 _ -________________________________________ -_ 99.5% _ 24.5 _ -________________________________________ -_ 99.9% _ 31.0 _ -________________________________________ -_ max _ 32.6 _ -________________________________________ -_ Recordings available: _ 4613 _ -________________________________________ -_ Features available: _ 4613 _ -________________________________________ -_ Supervisions available: _ 4613 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 06:25:52 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/spectrogram/libritts_cuts_test-clean.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 4837 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 08:34:09 _ -________________________________________ -_ mean _ 6.4 _ -________________________________________ -_ std _ 5.1 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 2.4 _ -________________________________________ -_ 50% _ 4.8 _ -________________________________________ -_ 75% _ 8.9 _ -________________________________________ -_ 99% _ 22.6 _ -________________________________________ -_ 99.5% _ 24.4 _ -________________________________________ -_ 99.9% _ 29.6 _ -________________________________________ -_ max _ 36.7 _ -________________________________________ -_ Recordings available: _ 4837 _ -________________________________________ -_ Features available: _ 4837 _ -________________________________________ -_ Supervisions available: _ 4837 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 08:34:09 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ - -./data/spectrogram/libritts_cuts_test-other.jsonl.gz statistics: -________________________________________ -_ Cuts count: _ 5120 _ -________________________________________ -_ Total 
duration (hh:mm:ss) _ 06:41:31 _ -________________________________________ -_ mean _ 4.7 _ -________________________________________ -_ std _ 3.8 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 1.8 _ -________________________________________ -_ 50% _ 3.6 _ -________________________________________ -_ 75% _ 6.5 _ -________________________________________ -_ 99% _ 17.8 _ -________________________________________ -_ 99.5% _ 20.4 _ -________________________________________ -_ 99.9% _ 23.8 _ -________________________________________ -_ max _ 27.3 _ -________________________________________ -_ Recordings available: _ 5120 _ -________________________________________ -_ Features available: _ 5120 _ -________________________________________ -_ Supervisions available: _ 5120 _ -________________________________________ -SUPERVISION custom fields: -Speech duration statistics: -__________________________________________________________________ -_ Total speech duration _ 06:41:31 _ 100.00% of recording _ -__________________________________________________________________ -_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ -__________________________________________________________________ -_ Total silence duration _ 00:00:01 _ 0.00% of recording _ -__________________________________________________________________ -""" diff --git a/egs/libritts/CODEC/local/validate_manifest.py b/egs/libritts/CODEC/local/validate_manifest.py deleted file mode 120000 index b4d52ebca..000000000 --- a/egs/libritts/CODEC/local/validate_manifest.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh deleted file mode 100755 index da04249ac..000000000 --- a/egs/libritts/CODEC/prepare.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=100 -sampling_rate=24000 -nj=32 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/LibriTTS, - # you can create a symlink - # - # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS - # - if [ ! -d $dl_dir/LibriTTS ]; then - lhotse download libritts $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LibriTTS manifest" - # We assume that you have downloaded the LibriTTS corpus - # to $dl_dir/LibriTTS - mkdir -p data/manifests - if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests - touch data/manifests/.libritts.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute Spectrogram for LibriTTS" - mkdir -p data/spectrogram - if [ ! 
-e data/spectrogram/.libritts.done ]; then - ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate - touch data/spectrogram/.libritts.done - fi - - # Here we shuffle and combine the train-clean-100, train-clean-360 and - # train-other-500 together to form the training set. - if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz - fi - - if [ ! -e data/spectrogram/.libritts-validated.done ]; then - log "Validating data/spectrogram for LibriTTS" - ./local/validate_manifest.py \ - data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz - touch data/spectrogram/.libritts-validated.done - fi -fi - diff --git a/egs/libritts/CODEC/shared b/egs/libritts/CODEC/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/libritts/CODEC/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md deleted file mode 100644 index 67424a1ca..000000000 --- a/egs/libritts/TTS/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# Introduction - -LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. -The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. -The main differences from the LibriSpeech corpus are listed below: -1. The audio files are at 24kHz sampling rate. -2. The speech is split at sentence breaks. -3. Both original and normalized texts are included. -4. Contextual information (e.g., neighbouring sentences) can be extracted. -5. Utterances with significant background noise are excluded. -For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. - -> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). -> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities. -> -> By using this framework, you agree to the following: -> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> -> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. 
We explicitly disclaim any liability for misuse of the technology. -> -> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> -> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. - - -# VITS - -This recipe provides a VITS model trained on the LibriTTS dataset. - -Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30). - -The training command is given below: -``` -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -./vits/train.py \ - --world-size 4 \ - --num-epochs 400 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir vits/exp \ - --max-duration 500 -``` - -To inference, use: -``` -./vits/infer.py \ - --exp-dir vits/exp \ - --epoch 400 \ - --tokens data/tokens.txt -``` - -# [VALL-E](https://arxiv.org/abs/2301.02111) - -./valle contains the code for training VALL-E TTS model. - -Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). The demo of the model trained with libritts and [libritts-r](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo). - -Preparation: - -``` -bash prepare.sh --start-stage 4 -``` - -The training command is given below: - -``` -world_size=8 -exp_dir=exp/valle - -## Train AR model -python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ - --exp-dir ${exp_dir} --world-size ${world_size} - -## Train NAR model -# cd ${exp_dir} -# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 -# cd - -python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ - --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ - --exp-dir ${exp_dir} --world-size ${world_size} -``` - -To inference, use: -``` -huggingface-cli login -huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts -top_p=1.0 -python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ - --top-k -1 --temperature 1.0 \ - --text ./libritts.txt \ - --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p} -``` diff --git a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py deleted file mode 120000 index 68579ffd4..000000000 --- a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of 
file diff --git a/egs/libritts/TTS/local/compute_spectrogram_libritts.py b/egs/libritts/TTS/local/compute_spectrogram_libritts.py deleted file mode 120000 index 5a6ebba58..000000000 --- a/egs/libritts/TTS/local/compute_spectrogram_libritts.py +++ /dev/null @@ -1 +0,0 @@ -../../CODEC/local/compute_spectrogram_libritts.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_token_file.py b/egs/libritts/TTS/local/prepare_token_file.py deleted file mode 120000 index afc29a22b..000000000 --- a/egs/libritts/TTS/local/prepare_token_file.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/prepare_token_file.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py deleted file mode 100755 index cdc39ea6b..000000000 --- a/egs/libritts/TTS/local/prepare_tokens_libritts.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, -# Zengrui Jin,) -# 2024 Tsinghua University (authors: Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
-""" - -import logging -from pathlib import Path - -import tacotron_cleaner.cleaners -from lhotse import CutSet, load_manifest -from piper_phonemize import phonemize_espeak -from tqdm.auto import tqdm - - -def prepare_tokens_libritts(): - output_dir = Path("data/spectrogram") - prefix = "libritts" - suffix = "jsonl.gz" - partitions = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-all-shuf", - "train-clean-460", - # "train-clean-100", - # "train-clean-360", - # "train-other-500", - ) - - for partition in partitions: - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - new_cuts = [] - for cut in tqdm(cut_set): - # Each cut only contains one supervision - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) - text = cut.supervisions[0].text - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens_list = phonemize_espeak(text, "en-us") - tokens = [] - for t in tokens_list: - tokens.extend(t) - cut.tokens = tokens - cut.supervisions[0].normalized_text = text - - new_cuts.append(cut) - - new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file( - output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_tokens_libritts() diff --git a/egs/libritts/TTS/local/validate_manifest.py b/egs/libritts/TTS/local/validate_manifest.py deleted file mode 120000 index b4d52ebca..000000000 --- a/egs/libritts/TTS/local/validate_manifest.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh deleted file mode 100755 index a0a6d2ae1..000000000 --- a/egs/libritts/TTS/prepare.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=100 -sampling_rate=24000 -nj=32 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib" - if [ ! -d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib already built" - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/LibriTTS, - # you can create a symlink - # - # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS - # - if [ ! -d $dl_dir/LibriTTS ]; then - lhotse download libritts $dl_dir - fi - - if [ ! 
-d $dl_dir/xvector_nnet_1a_libritts_clean_460 ]; then - log "Downloading x-vector" - - git clone https://huggingface.co/datasets/zrjin/xvector_nnet_1a_libritts_clean_460 $dl_dir/xvector_nnet_1a_libritts_clean_460 - - mkdir -p exp/xvector_nnet_1a/ - cp -r $dl_dir/xvector_nnet_1a_libritts_clean_460/* exp/xvector_nnet_1a/ - fi - -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LibriTTS manifest" - # We assume that you have downloaded the LibriTTS corpus - # to $dl_dir/LibriTTS - mkdir -p data/manifests - if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests - touch data/manifests/.libritts.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute Spectrogram for LibriTTS" - mkdir -p data/spectrogram - if [ ! -e data/spectrogram/.libritts.done ]; then - ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate - touch data/spectrogram/.libritts.done - fi - - # Here we shuffle and combine the train-clean-100, train-clean-360 and - # train-other-500 together to form the training set. - if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz - fi - - # Here we shuffle and combine the train-clean-100, train-clean-360 - # together to form the training set. - if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then - cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) | \ - shuf | gzip -c > data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz - fi - - if [ ! -e data/spectrogram/.libritts-validated.done ]; then - log "Validating data/spectrogram for LibriTTS" - ./local/validate_manifest.py \ - data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz - touch data/spectrogram/.libritts-validated.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for LibriTTS" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: - # refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: - # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/spectrogram/.libritts_with_token.done ]; then - ./local/prepare_tokens_libritts.py - touch data/spectrogram/.libritts_with_token.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Generate token file" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: - # refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: - # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! 
-e data/tokens.txt ]; then - ./local/prepare_token_file.py --tokens data/tokens.txt - fi -fi - -audio_feats_dir=data/tokenized -dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean" -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Tokenize/Fbank LibriTTS for valle" - mkdir -p ${audio_feats_dir} - if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then - python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ - --audio-extractor "Encodec" \ - --batch-duration 400 \ - --src-dir "data/manifests" \ - --output-dir "${audio_feats_dir}" - fi - touch ${audio_feats_dir}/.libritts.tokenize.done - - lhotse combine \ - ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \ - ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \ - ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \ - ${audio_feats_dir}/cuts_train.jsonl.gz - lhotse copy \ - ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ - ${audio_feats_dir}/cuts_dev.jsonl.gz - lhotse copy \ - ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ - ${audio_feats_dir}/cuts_test.jsonl.gz -fi diff --git a/egs/libritts/TTS/shared b/egs/libritts/TTS/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/libritts/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/TTS/valle b/egs/libritts/TTS/valle deleted file mode 120000 index c8fe8fdb0..000000000 --- a/egs/libritts/TTS/valle +++ /dev/null @@ -1 +0,0 @@ -../../wenetspeech4tts/TTS/valle/ \ No newline at end of file diff --git a/egs/libritts/TTS/vits/duration_predictor.py b/egs/libritts/TTS/vits/duration_predictor.py deleted file mode 120000 index 9972b476f..000000000 --- a/egs/libritts/TTS/vits/duration_predictor.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/flow.py b/egs/libritts/TTS/vits/flow.py deleted file mode 120000 index e65d91ea7..000000000 --- a/egs/libritts/TTS/vits/flow.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/generator.py b/egs/libritts/TTS/vits/generator.py deleted file mode 120000 index 611679bfa..000000000 --- a/egs/libritts/TTS/vits/generator.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/hifigan.py b/egs/libritts/TTS/vits/hifigan.py deleted file mode 120000 index 5ac025de7..000000000 --- a/egs/libritts/TTS/vits/hifigan.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/infer.py b/egs/libritts/TTS/vits/infer.py deleted file mode 100755 index 675678606..000000000 --- a/egs/libritts/TTS/vits/infer.py +++ /dev/null @@ -1,280 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script performs model inference on test set. - -Usage: -./vits/infer.py \ - --epoch 1000 \ - --exp-dir ./vits/exp \ - --max-duration 500 -""" - - -import argparse -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path -from typing import List - -import k2 -import numpy as np -import torch -import torch.nn as nn -import torchaudio -from lhotse.features.io import KaldiReader -from tokenizer import Tokenizer -from train import get_model, get_params -from tts_datamodule import LibrittsTtsDataModule - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - return parser - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - subset: str, - params: AttributeDict, - model: nn.Module, - tokenizer: Tokenizer, - speaker_map: KaldiReader, -) -> None: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - tokenizer: - Used to convert text to phonemes. - """ - - # Background worker save audios to disk. - def _save_worker( - subset: str, - batch_size: int, - cut_ids: List[str], - audio: torch.Tensor, - audio_pred: torch.Tensor, - audio_lens: List[int], - audio_lens_pred: List[int], - ): - for i in range(batch_size): - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), - audio[i : i + 1, : audio_lens[i]], - sample_rate=params.sampling_rate, - ) - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), - audio_pred[i : i + 1, : audio_lens_pred[i]], - sample_rate=params.sampling_rate, - ) - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - futures = [] - with ThreadPoolExecutor(max_workers=1) as executor: - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["tokens"]) - - tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - audio = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids] - spembs = ( - torch.Tensor(np.array([speaker_map.read(sid) for sid in sids])) - .squeeze(1) - .to(device) - ) - - audio_pred, _, durations = model.inference_batch( - text=tokens, - text_lengths=tokens_lens, - spembs=spembs, - ) - audio_pred = audio_pred.detach().cpu() - # convert to samples - audio_lens_pred = ( - (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() - ) - - futures.append( - executor.submit( - _save_worker, - subset, - batch_size, - cut_ids, - audio, - audio_pred, - audio_lens, - audio_lens_pred, - ) - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - # return results - for f in futures: - f.result() - - -@torch.no_grad() -def main(): - parser = get_parser() - LibrittsTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - # we need cut ids to display recognition results. 
- args.return_cuts = True - libritts = LibrittsTtsDataModule(args) - - logging.info(f"Device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to(device) - model.eval() - - num_param_g = sum([p.numel() for p in model.generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - test_clean_cuts = libritts.test_clean_cuts() - test_clean_speaker_map = libritts.test_clean_xvector() - test_clean_dl = libritts.test_dataloaders(test_clean_cuts) - - dev_clean_cuts = libritts.dev_clean_cuts() - dev_clean_speaker_map = libritts.dev_clean_xvector() - dev_clean_dl = libritts.dev_dataloaders(dev_clean_cuts) - - infer_sets = { - "test-clean": (test_clean_dl, test_clean_speaker_map), - "dev-clean": (dev_clean_dl, dev_clean_speaker_map), - } - - for subset, data in infer_sets.items(): - save_wav_dir = params.res_dir / "wav" / subset - save_wav_dir.mkdir(parents=True, exist_ok=True) - dl, speaker_map = data - - logging.info(f"Processing {subset} set, saving to {save_wav_dir}") - - infer_dataset( - dl=dl, - subset=subset, - params=params, - model=model, - tokenizer=tokenizer, - speaker_map=speaker_map, - ) - - logging.info(f"Wav files are saved to {params.save_wav_dir}") - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libritts/TTS/vits/loss.py b/egs/libritts/TTS/vits/loss.py deleted file mode 120000 index 672e5ff68..000000000 --- a/egs/libritts/TTS/vits/loss.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/monotonic_align b/egs/libritts/TTS/vits/monotonic_align deleted file mode 120000 index 71934e7cc..000000000 --- a/egs/libritts/TTS/vits/monotonic_align +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/libritts/TTS/vits/posterior_encoder.py b/egs/libritts/TTS/vits/posterior_encoder.py deleted file mode 120000 index 41d64a3a6..000000000 --- a/egs/libritts/TTS/vits/posterior_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/residual_coupling.py b/egs/libritts/TTS/vits/residual_coupling.py deleted file mode 120000 index f979adbf0..000000000 --- a/egs/libritts/TTS/vits/residual_coupling.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/test_onnx.py b/egs/libritts/TTS/vits/test_onnx.py deleted file mode 100755 index ae6587338..000000000 --- a/egs/libritts/TTS/vits/test_onnx.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to test the exported onnx model by vits/export-onnx.py - -Use the onnx model to generate a wav: -./vits/test_onnx.py \ - --model-filename vits/exp/vits-epoch-1000.onnx \ - --tokens data/tokens.txt -""" - - -import argparse -import logging -from pathlib import Path - -import onnxruntime as ort -import torch -import torchaudio -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model.", - ) - - parser.add_argument( - "--speakers", - type=Path, - default=Path("data/speakers.txt"), - help="Path to speakers.txt file.", - ) - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - return parser - - -class OnnxModel: - def __init__(self, model_filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.model = ort.InferenceSession( - model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - - def __call__( - self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor - ) -> torch.Tensor: - """ - Args: - tokens: - A 1-D tensor of shape (1, T) - Returns: - A tensor of shape (1, T') - """ - noise_scale = torch.tensor([0.667], dtype=torch.float32) - noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) - alpha = torch.tensor([1.0], dtype=torch.float32) - - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: tokens.numpy(), - self.model.get_inputs()[1].name: tokens_lens.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: alpha.numpy(), - self.model.get_inputs()[4].name: noise_scale_dur.numpy(), - self.model.get_inputs()[5].name: speaker.numpy(), - }, - )[0] - return torch.from_numpy(out) - - -def main(): - args = get_parser().parse_args() - - tokenizer = Tokenizer(args.tokens) - - with open(args.speakers) as f: - speaker_map = {line.strip(): i for i, line in enumerate(f)} - args.num_spks = len(speaker_map) - - logging.info("About to create onnx model") - model = OnnxModel(args.model_filename) - - text = "I went there to see the land, the people and how their system works, end quote." 
- tokens = tokenizer.texts_to_token_ids( - [text], intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = torch.tensor(tokens) # (1, T) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) - speaker = torch.tensor([1], dtype=torch.int64) # (1, ) - audio = model(tokens, tokens_lens, speaker) # (1, T') - - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libritts/TTS/vits/text_encoder.py b/egs/libritts/TTS/vits/text_encoder.py deleted file mode 120000 index 0efba277e..000000000 --- a/egs/libritts/TTS/vits/text_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tokenizer.py b/egs/libritts/TTS/vits/tokenizer.py deleted file mode 120000 index 057b0dc4b..000000000 --- a/egs/libritts/TTS/vits/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py deleted file mode 100755 index 447fbcf5d..000000000 --- a/egs/libritts/TTS/vits/train.py +++ /dev/null @@ -1,1015 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.cut import Cut -from lhotse.features.io import KaldiReader -from lhotse.utils import fix_random_seed -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LibrittsTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--lr", type=float, default=2.0e-4, help="The base learning rate." - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=20, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. 
- - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - # training params - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 200, - "env_info": get_env_info(), - "sampling_rate": 24000, - "frame_shift": 256, - "frame_length": 1024, - "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "n_mels": 80, - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_mel": 45.0, # loss scaling coefficient for Mel loss - "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss - "lambda_dur": 1.0, # loss scaling coefficient for duration loss - "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = { - "n_mels": params.n_mels, - "frame_length": params.frame_length, - "frame_shift": params.frame_shift, - } - generator_params = { - "hidden_channels": 192, - "spks": None, - "langs": None, - "spk_embed_dim": 512, - "global_channels": 256, - "segment_size": 32, - "text_encoder_attention_heads": 2, - "text_encoder_ffn_expand": 4, - "text_encoder_cnn_module_kernel": 5, - "text_encoder_blocks": 6, - "text_encoder_dropout_rate": 0.1, - "decoder_kernel_size": 7, - "decoder_channels": 512, - "decoder_upsample_scales": [8, 8, 2, 2], - "decoder_upsample_kernel_sizes": [16, 16, 4, 4], - "decoder_resblock_kernel_sizes": [3, 7, 11], - "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "use_weight_norm_in_decoder": True, - "posterior_encoder_kernel_size": 5, - "posterior_encoder_layers": 16, - "posterior_encoder_stacks": 1, - "posterior_encoder_base_dilation": 1, - "posterior_encoder_dropout_rate": 0.0, - "use_weight_norm_in_posterior_encoder": True, - "flow_flows": 4, - "flow_kernel_size": 5, - "flow_base_dilation": 1, - "flow_layers": 4, - "flow_dropout_rate": 0.0, - "use_weight_norm_in_flow": True, - "use_only_mean_in_flow": True, - "stochastic_duration_predictor_kernel_size": 3, - "stochastic_duration_predictor_dropout_rate": 0.5, - "stochastic_duration_predictor_flows": 4, - "stochastic_duration_predictor_dds_conv_layers": 3, - } - model = VITS( - vocab_size=params.vocab_size, - feature_dim=params.feature_dim, - sampling_rate=params.sampling_rate, - generator_params=generator_params, - mel_loss_params=mel_loss_params, - lambda_adv=params.lambda_adv, - lambda_mel=params.lambda_mel, - lambda_feat_match=params.lambda_feat_match, - lambda_dur=params.lambda_dur, - lambda_kl=params.lambda_kl, - ) - return model - - -def prepare_input( - batch: dict, - tokenizer: Tokenizer, - device: torch.device, - speaker_map: KaldiReader, -): - """Parse batch data""" - - def parse_sids(batch: dict) -> List[str]: - return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]] - - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - spembs = ( - torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)])) - .squeeze(1) - .to(device) - ) - - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - dev_dl: torch.utils.data.DataLoader, - train_speaker_map: KaldiReader, - dev_speaker_map: KaldiReader, - scaler: GradScaler, - tb_writer: 
Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to phonemes. - optimizer_g: - The optimizer for generator. - optimizer_d: - The optimizer for discriminator. - scheduler_g: - The learning rate scheduler for generator, we call step() every epoch. - scheduler_d: - The learning rate scheduler for discriminator, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - spembs, - ) = prepare_input(batch, tokenizer, device, train_speaker_map) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=False, - ) - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(loss_d).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(loss_g).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/speech_", - speech_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - dev_dl=dev_dl, - dev_speaker_map=dev_speaker_map, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - dev_dl: torch.utils.data.DataLoader, - dev_speaker_map: KaldiReader, - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - returned_sample = None - - with torch.no_grad(): - for batch_idx, batch in enumerate(dev_dl): - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - spembs, - ) = prepare_input(batch, tokenizer, device, dev_speaker_map) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # infer for first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference( - text=tokens[0, : tokens_lens[0].item()], - spembs=spembs[0], - ) - audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = ( - (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - ) - assert audio_len_pred == len(audio_pred), ( - audio_len_pred, - len(audio_pred), - ) - audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() - returned_sample = (audio_pred, audio_gt) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - tokenizer: Tokenizer, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - train_speaker_map: KaldiReader, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - spembs, - ) = prepare_input(batch, tokenizer, device, train_speaker_map) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=False, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - spembs=spembs, - forward_generator=True, - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - libritts = LibrittsTtsDataModule(args) - - if params.full_libri: - train_cuts = libritts.train_all_shuf_cuts() - train_speaker_map = libritts.train_all_shuf_xvector() - else: - train_cuts = libritts.train_clean_460_cuts() - train_speaker_map = libritts.train_clean_460_xvector() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - generator = model.generator - discriminator = model.discriminator - - num_param_g = sum([p.numel() for p in generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer_g = torch.optim.AdamW( - generator.parameters(), 
lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - optimizer_d = torch.optim.AdamW( - discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = libritts.train_dataloaders(train_cuts) - - dev_clean_cuts = libritts.dev_clean_cuts() - dev_speaker_map = libritts.dev_clean_xvector() - dev_dl = libritts.dev_dataloaders(dev_clean_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - train_speaker_map=train_speaker_map, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - dev_dl=dev_dl, - train_speaker_map=train_speaker_map, - dev_speaker_map=dev_speaker_map, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if 
rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - # step per epoch - scheduler_g.step() - scheduler_d.step() - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - LibrittsTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libritts/TTS/vits/transform.py b/egs/libritts/TTS/vits/transform.py deleted file mode 120000 index 962647408..000000000 --- a/egs/libritts/TTS/vits/transform.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tts_datamodule.py b/egs/libritts/TTS/vits/tts_datamodule.py deleted file mode 100644 index e98e49c1f..000000000 --- a/egs/libritts/TTS/vits/tts_datamodule.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.features.io import KaldiReader -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -LIBRITTS_SAMPLING_RATE = 24000 - - -class LibrittsTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). 
- - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=False, - help="""When enabled, use the entire LibriTTS training set. - Otherwise, use the 460h clean subset.""", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--speaker-embeds", - type=Path, - default=Path("exp/xvector_nnet_1a/"), - help="Path to directory with speaker embeddings.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = LIBRITTS_SAMPLING_RATE - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = LIBRITTS_SAMPLING_RATE - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - dev_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create valid dataloader") - dev_dl = DataLoader( - validate, - sampler=dev_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return dev_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = LIBRITTS_SAMPLING_RATE - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = 
SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def train_clean_460_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100 and train-clean-360 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir - / "libritts_cuts_with_tokens_train-clean-460.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." - ) - - @lru_cache() - def train_clean_460_xvector(self) -> KaldiReader: - logging.info("About to get speaker embeddings for train-clean-460") - return KaldiReader( - str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp") - ) - - @lru_cache() - def train_clean_100_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." - ) - - @lru_cache() - def train_clean_360_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." - ) - - @lru_cache() - def train_other_500_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." - ) - - @lru_cache() - def dev_clean_xvector(self) -> KaldiReader: - logging.info("About to get speaker embeddings for dev-clean") - return KaldiReader( - str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp") - ) - - @lru_cache() - def dev_other_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." - ) - - @lru_cache() - def test_clean_xvector(self) -> KaldiReader: - logging.info("About to get speaker embeddings for test-clean") - return KaldiReader( - str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp") - ) - - @lru_cache() - def test_other_xvector(self) -> KaldiReader: - raise NotImplementedError( - "Please implement the method to load speaker embeddings." 
- ) diff --git a/egs/libritts/TTS/vits/utils.py b/egs/libritts/TTS/vits/utils.py deleted file mode 120000 index 085e764b4..000000000 --- a/egs/libritts/TTS/vits/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/vits.py b/egs/libritts/TTS/vits/vits.py deleted file mode 120000 index 1f58cf6fe..000000000 --- a/egs/libritts/TTS/vits/vits.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/wavenet.py b/egs/libritts/TTS/vits/wavenet.py deleted file mode 120000 index 28f0a78ee..000000000 --- a/egs/libritts/TTS/vits/wavenet.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/.gitignore b/egs/ljspeech/TTS/.gitignore deleted file mode 100644 index d5c19797a..000000000 --- a/egs/ljspeech/TTS/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -build -core.c -*.so -my-output* -*.wav -*.onnx -generator_v* diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md deleted file mode 100644 index f5495eeaf..000000000 --- a/egs/ljspeech/TTS/README.md +++ /dev/null @@ -1,230 +0,0 @@ -# Introduction - -This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. -A transcription is provided for each clip. -Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours. - -The texts were published between 1884 and 1964, and are in the public domain. -The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain. - -The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/). - -# VITS - -This recipe provides a VITS model trained on the LJSpeech dataset. - -A pretrained model can be found [here](https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28). - -For a tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html). - -The training command is given below: -``` -export CUDA_VISIBLE_DEVICES=0,1,2,3 -./vits/train.py \ - --world-size 4 \ - --num-epochs 1000 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir vits/exp \ - --max-duration 500 -``` - -To run inference, use: -``` -./vits/infer.py \ - --exp-dir vits/exp \ - --epoch 1000 \ - --tokens data/tokens.txt -``` - -## Quality vs speed - -If you feel that the trained model is slow at runtime, you can specify the -argument `--model-type` during training. Possible values are: - - - `low`, means **low** quality. The resulting model is very small in file size - and runs very fast. The following is a wave file generated by a `low` quality model - - https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633 - - The text is `Ask not what your country can do for you; ask what you can do for your country.` - - The exported onnx model has a file size of ``26.8 MB`` (float32). - - - `medium`, means **medium** quality. - The following is a wave file generated by a `medium` quality model - - https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac - - The text is `Ask not what your country can do for you; ask what you can do for your country.` - - The exported onnx model has a file size of ``70.9 MB`` (float32). - - - `high`, means **high** quality. This is the default value.
- - The following is a wave file generated by a `high` quality model - - https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc - - The text is `Ask not what your country can do for you; ask what you can do for your country.` - - The exported onnx model has a file size of ``113 MB`` (float32). - - -A pre-trained `low` model trained using 4xV100 32GB GPUs with the following command can be found at - - -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 -./vits/train.py \ - --world-size 4 \ - --num-epochs 1601 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir vits/exp \ - --model-type low \ - --max-duration 800 -``` - -A pre-trained `medium` model trained using 4xV100 32GB GPUs with the following command can be found at - -```bash -export CUDA_VISIBLE_DEVICES=4,5,6,7 -./vits/train.py \ - --world-size 4 \ - --num-epochs 1000 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir vits/exp-medium \ - --model-type medium \ - --max-duration 500 - -# (Note: training was stopped after `epoch-820.pt`) -``` -# matcha - -[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS). - -This recipe provides a Matcha-TTS model trained on the LJSpeech dataset. - -Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28). -The pull-request for this recipe can be found at - -The training command is given below: -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 - -python3 ./matcha/train.py \ - --exp-dir ./matcha/exp-new-3/ \ - --num-workers 4 \ - --world-size 4 \ - --num-epochs 4000 \ - --max-duration 1000 \ - --bucketing-sampler 1 \ - --start-epoch 1 -``` - -To run inference, use: - -```bash -# Download the Hifigan vocoder. We use Hifigan v1 below. You can select from v1, v2, or v3 - -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 - -./matcha/infer.py \ - --exp-dir ./matcha/exp-new-3 \ - --epoch 4000 \ - --tokens ./data/tokens.txt \ - --vocoder ./generator_v1 \ - --input-text "how are you doing?" \ - --output-wav ./generated.wav -``` - -```bash -soxi ./generated.wav -``` -prints: -``` -Input File : './generated.wav' -Channels : 1 -Sample Rate : 22050 -Precision : 16-bit -Duration : 00:00:01.29 = 28416 samples ~ 96.6531 CDDA sectors -File Size : 56.9k -Bit Rate : 353k -Sample Encoding: 16-bit Signed Integer PCM -``` - -To export the checkpoint to onnx: - -```bash -# export the acoustic model to onnx - -./matcha/export_onnx.py \ - --exp-dir ./matcha/exp-new-3 \ - --epoch 4000 \ - --tokens ./data/tokens.txt -``` - -The above command generates the following files: - - - model-steps-2.onnx - - model-steps-3.onnx - - model-steps-4.onnx - - model-steps-5.onnx - - model-steps-6.onnx - -where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. - -**HINT**: If you get the following error while running `export_onnx.py`: - -``` -torch.onnx.errors.UnsupportedOperatorError: Exporting the operator -'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported. -``` - -please use `torch>=2.2.0`.
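-Before wiring the exported acoustic model to a vocoder, it can be useful to sanity-check one of the `model-steps-*.onnx` files directly with `onnxruntime`. The snippet below is only a minimal sketch and is not part of the recipe; the input names `x`, `x_lengths`, `noise_scale`, and `length_scale` are assumptions based on the export wrapper in `./matcha/export_onnx.py`, so print `sess.get_inputs()` first if they differ.
-
-```python
-import numpy as np
-import onnxruntime as ort
-
-# Load one of the exported acoustic models on CPU.
-sess = ort.InferenceSession("model-steps-6.onnx", providers=["CPUExecutionProvider"])
-
-# Inspect the exported input names/shapes before feeding anything.
-for node in sess.get_inputs():
-    print(node.name, node.shape, node.type)
-
-tokens = np.array([[10, 20, 30, 40]], dtype=np.int64)     # (batch, num_tokens), hypothetical token IDs
-token_lens = np.array([tokens.shape[1]], dtype=np.int64)  # (batch,)
-noise_scale = np.array([1.0], dtype=np.float32)
-length_scale = np.array([1.0], dtype=np.float32)
-
-(mel,) = sess.run(
-    None,
-    {
-        "x": tokens,
-        "x_lengths": token_lens,
-        "noise_scale": noise_scale,
-        "length_scale": length_scale,
-    },
-)
-print(mel.shape)  # expected: (batch, feat_dim, num_frames)
-```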
- - -To export the Hifigan vocoder to onnx, please use: - -```bash -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 -wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 - -python3 ./matcha/export_onnx_hifigan.py -``` - -The above command generates 3 files: - - - hifigan_v1.onnx - - hifigan_v2.onnx - - hifigan_v3.onnx - -To use the generated onnx files to generate speech from text, please run: - -```bash -python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v1.onnx \ - --tokens ./data/tokens.txt \ - --input-text "Ask not what your country can do for you; ask what you can do for your country." \ - --output-wav ./matcha-epoch-4000-step6-hifigan-v1.wav -``` - -```bash -soxi ./matcha-epoch-4000-step6-hifigan-v1.wav - -Input File : './matcha-epoch-4000-step6-hifigan-v1.wav' -Channels : 1 -Sample Rate : 22050 -Precision : 16-bit -Duration : 00:00:05.46 = 120320 samples ~ 409.252 CDDA sectors -File Size : 241k -Bit Rate : 353k -Sample Encoding: 16-bit Signed Integer PCM -``` - -https://github.com/user-attachments/assets/b7c197a6-3870-49c6-90ca-db4d3776869b - diff --git a/egs/ljspeech/TTS/local/audio.py b/egs/ljspeech/TTS/local/audio.py deleted file mode 120000 index b70d91c92..000000000 --- a/egs/ljspeech/TTS/local/audio.py +++ /dev/null @@ -1 +0,0 @@ -../matcha/audio.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py deleted file mode 100755 index 296f9a4f4..000000000 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LJSpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from fbank import MatchaFbank, MatchaFbankConfig -from lhotse import CutSet, LilcomChunkyWriter, load_manifest -from lhotse.audio import RecordingSet -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-jobs", - type=int, - default=4, - help="""Number of parallel jobs to use when computing the features.
- """, - ) - return parser - - -def compute_fbank_ljspeech(num_jobs: int): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - if num_jobs < 1: - num_jobs = os.cpu_count() - - logging.info(f"num_jobs: {num_jobs}") - logging.info(f"src_dir: {src_dir}") - logging.info(f"output_dir: {output_dir}") - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=22050, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - - prefix = "ljspeech" - suffix = "jsonl.gz" - partition = "all" - - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet - ) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet - ) - - extractor = MatchaFbank(config) - - with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - # Torch's multithreaded behavior needs to be disabled or - # it wastes a lot of CPU and slow things down. - # Do this outside of main() in case it needs to take effect - # even when we are not invoking the main (e.g. when spawning subprocesses). - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_parser().parse_args() - compute_fbank_ljspeech(args.num_jobs) diff --git a/egs/ljspeech/TTS/local/compute_fbank_statistics.py b/egs/ljspeech/TTS/local/compute_fbank_statistics.py deleted file mode 100755 index d0232c983..000000000 --- a/egs/ljspeech/TTS/local/compute_fbank_statistics.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script compute the mean and std of the fbank features. 
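-It makes a single pass over the cuts, accumulating sum(x) and sum(x^2) over all frames and mel bins, and derives the std via Var(x) = E[x^2] - (E[x])^2.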
-""" - -import argparse -import json -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - parser.add_argument( - "cmvn", - type=Path, - help="Path to the cmvn.json", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info( - f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}" - ) - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet), type(cut_set) - - feat_dim = cut_set[0].features.num_features - num_frames = 0 - s = 0 - sq = 0 - for c in cut_set: - f = torch.from_numpy(c.load_features()) - num_frames += f.shape[0] - s += f.sum() - sq += f.square().sum() - - fbank_mean = s / (num_frames * feat_dim) - fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean - print("fbank var", fbank_var) - fbank_std = fbank_var.sqrt() - with open(args.cmvn, "w") as f: - json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f) - f.write("\n") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py deleted file mode 100755 index 97c9008fc..000000000 --- a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LJSpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated spectrogram features are saved in data/spectrogram. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - LilcomChunkyWriter, - Spectrogram, - SpectrogramConfig, - load_manifest, -) -from lhotse.audio import RecordingSet -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_spectrogram_ljspeech(): - src_dir = Path("data/manifests") - output_dir = Path("data/spectrogram") - num_jobs = min(4, os.cpu_count()) - - sampling_rate = 22050 - frame_length = 1024 / sampling_rate # (in second) - frame_shift = 256 / sampling_rate # (in second) - use_fft_mag = True - - prefix = "ljspeech" - suffix = "jsonl.gz" - partition = "all" - - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet - ) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet - ) - - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=frame_length, - frame_shift=frame_shift, - use_fft_mag=use_fft_mag, - ) - extractor = Spectrogram(config) - - with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_spectrogram_ljspeech() diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py deleted file mode 100755 index 93f0044f0..000000000 --- a/egs/ljspeech/TTS/local/display_manifest_statistics.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in vits/train.py -for usage. 
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz" - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Cut statistics: - ╒═══════════════════════════╤══════════╕ - │ Cuts count: │ 13100 │ - ├───────────────────────────┼──────────┤ - │ Total duration (hh:mm:ss) │ 23:55:18 │ - ├───────────────────────────┼──────────┤ - │ mean │ 6.6 │ - ├───────────────────────────┼──────────┤ - │ std │ 2.2 │ - ├───────────────────────────┼──────────┤ - │ min │ 1.1 │ - ├───────────────────────────┼──────────┤ - │ 25% │ 5.0 │ - ├───────────────────────────┼──────────┤ - │ 50% │ 6.8 │ - ├───────────────────────────┼──────────┤ - │ 75% │ 8.4 │ - ├───────────────────────────┼──────────┤ - │ 99% │ 10.0 │ - ├───────────────────────────┼──────────┤ - │ 99.5% │ 10.1 │ - ├───────────────────────────┼──────────┤ - │ 99.9% │ 10.1 │ - ├───────────────────────────┼──────────┤ - │ max │ 10.1 │ - ├───────────────────────────┼──────────┤ - │ Recordings available: │ 13100 │ - ├───────────────────────────┼──────────┤ - │ Features available: │ 13100 │ - ├───────────────────────────┼──────────┤ - │ Supervisions available: │ 13100 │ - ╘═══════════════════════════╧══════════╛ -""" diff --git a/egs/ljspeech/TTS/local/fbank.py b/egs/ljspeech/TTS/local/fbank.py deleted file mode 120000 index 5bcf1fde5..000000000 --- a/egs/ljspeech/TTS/local/fbank.py +++ /dev/null @@ -1 +0,0 @@ -../matcha/fbank.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py deleted file mode 100755 index 5b048b600..000000000 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file generates the file that maps tokens to IDs. 
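-The output file contains one `token token_id` pair per line, sorted by token ID.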
-""" - -import argparse -import logging -from pathlib import Path -from typing import Dict - -from piper_phonemize import get_espeak_map - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--tokens", - type=Path, - default=Path("data/tokens.txt"), - help="Path to the dict that maps the text tokens to IDs", - ) - - return parser.parse_args() - - -def get_token2id(filename: Path) -> Dict[str, int]: - """Get a dict that maps token to IDs, and save it to the given filename.""" - all_tokens = get_espeak_map() # token: [token_id] - all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()} - # sort by token_id - all_tokens = sorted(all_tokens.items(), key=lambda x: x[1]) - - with open(filename, "w", encoding="utf-8") as f: - for token, token_id in all_tokens: - f.write(f"{token} {token_id}\n") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - out_file = Path(args.tokens) - get_token2id(out_file) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py deleted file mode 100755 index 33a8ac2ab..000000000 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
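-Phonemes are obtained with piper_phonemize's espeak backend (en-us) after applying the tacotron text cleaners to each supervision's normalized text.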
-""" - -import logging -from pathlib import Path - -try: - import tacotron_cleaner.cleaners -except ModuleNotFoundError as ex: - raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") - -import argparse - -from lhotse import CutSet, load_manifest -from piper_phonemize import phonemize_espeak - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--in-out-dir", - type=Path, - required=True, - help="Input and output directory", - ) - - return parser - - -def prepare_tokens_ljspeech(in_out_dir): - prefix = "ljspeech" - suffix = "jsonl.gz" - partition = "all" - - cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}") - - new_cuts = [] - for cut in cut_set: - # Each cut only contains one supervision - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) - text = cut.supervisions[0].normalized_text - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens_list = phonemize_espeak(text, "en-us") - tokens = [] - for t in tokens_list: - tokens.extend(t) - cut.tokens = tokens - new_cuts.append(cut) - - new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_parser().parse_args() - - prepare_tokens_ljspeech(args.in_out_dir) diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py deleted file mode 100755 index 68159ae03..000000000 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut - -We will add more checks later if needed. 
- -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/spectrogram/ljspeech_cuts_all.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset.speech_synthesis import validate_for_tts - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet), type(cut_set) - - validate_for_tts(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ljspeech/TTS/matcha/LICENSE b/egs/ljspeech/TTS/matcha/LICENSE deleted file mode 100644 index 858018e75..000000000 --- a/egs/ljspeech/TTS/matcha/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 Shivam Mehta - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/egs/ljspeech/TTS/matcha/__init__.py b/egs/ljspeech/TTS/matcha/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ljspeech/TTS/matcha/audio.py b/egs/ljspeech/TTS/matcha/audio.py deleted file mode 100644 index 534331e59..000000000 --- a/egs/ljspeech/TTS/matcha/audio.py +++ /dev/null @@ -1,92 +0,0 @@ -# This file is copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py -import numpy as np -import torch -import torch.utils.data -from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read - -MAX_WAV_VALUE = 32768.0 - - -def load_wav(full_path): - sampling_rate, data = read(full_path) - return data, sampling_rate - - -def dynamic_range_compression(x, C=1, clip_val=1e-5): - return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) - - -def dynamic_range_decompression(x, C=1): - return np.exp(x) / C - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def mel_spectrogram( - y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False -): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement - if f"{str(fmax)}_{str(y.device)}" not in mel_basis: - mel = librosa_mel_fn( - sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax - ) - mel_basis[str(fmax) + "_" + str(y.device)] = ( - torch.from_numpy(mel).float().to(y.device) - ) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py deleted file mode 100755 index 3c653fbf1..000000000 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -""" -This script exports a Matcha-TTS model to ONNX. -Note that the model outputs fbank. You need to use a vocoder to convert -it to audio. 
See also ./export_onnx_hifigan.py -""" - -import argparse -import json -import logging -from pathlib import Path -from typing import Any, Dict - -import onnx -import torch -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp-new-3", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, Any]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - - while len(model.metadata_props): - model.metadata_props.pop() - - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class ModelWrapper(torch.nn.Module): - def __init__(self, model, num_steps: int = 5): - super().__init__() - self.model = model - self.num_steps = num_steps - - def forward( - self, - x: torch.Tensor, - x_lengths: torch.Tensor, - noise_scale: torch.Tensor, - length_scale: torch.Tensor, - ) -> torch.Tensor: - """ - Args: : - x: (batch_size, num_tokens), torch.int64 - x_lengths: (batch_size,), torch.int64 - noise_scale: (1,), torch.float32 - length_scale (1,), torch.float32 - Returns: - audio: (batch_size, num_samples) - - """ - mel = self.model.synthesise( - x=x, - x_lengths=x_lengths, - n_timesteps=self.num_steps, - temperature=noise_scale, - length_scale=length_scale, - )["mel"] - # mel: (batch_size, feat_dim, num_frames) - - return mel - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - for num_steps in [2, 3, 4, 5, 6]: - logging.info(f"num_steps: {num_steps}") - wrapper = ModelWrapper(model, num_steps=num_steps) - wrapper.eval() - - # Use a large value so the rotary position embedding in the text - # encoder has a large initial length - x = torch.ones(1, 1000, dtype=torch.int64) - x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1.0]) - length_scale = torch.tensor([1.0]) - - opset_version = 14 - filename = f"model-steps-{num_steps}.onnx" - torch.onnx.export( - 
wrapper, - (x, x_lengths, noise_scale, length_scale), - filename, - opset_version=opset_version, - input_names=["x", "x_length", "noise_scale", "length_scale"], - output_names=["mel"], - dynamic_axes={ - "x": {0: "N", 1: "L"}, - "x_length": {0: "N"}, - "mel": {0: "N", 2: "L"}, - }, - ) - - meta_data = { - "model_type": "matcha-tts", - "language": "English", - "voice": "en-us", - "has_espeak": 1, - "jieba": 0, - "n_speakers": 1, - "sample_rate": 22050, - "version": 1, - "pad_id": tokenizer.pad_id, - "model_author": "icefall", - "maintainer": "k2-fsa", - "use_eos_bos": 1, - "dataset": "LJ Speech", - "dataset_url": "https://keithito.com/LJ-Speech-Dataset/", - "num_ode_steps": num_steps, - } - add_meta_data(filename=filename, meta_data=meta_data) - print(meta_data) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py deleted file mode 100755 index 5c96b3bc7..000000000 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -import logging -from pathlib import Path -from typing import Any, Dict - -import onnx -import torch -from infer import load_vocoder - - -def add_meta_data(filename: str, meta_data: Dict[str, Any]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - - while len(model.metadata_props): - model.metadata_props.pop() - - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class ModelWrapper(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward( - self, - mel: torch.Tensor, - ) -> torch.Tensor: - """ - Args: : - mel: (batch_size, feat_dim, num_frames), torch.float32 - Returns: - audio: (batch_size, num_samples), torch.float32 - """ - audio = self.model(mel).clamp(-1, 1).squeeze(1) - return audio - - -@torch.inference_mode() -def main(): - # Please go to - # https://github.com/csukuangfj/models/tree/master/hifigan - # to download the following files - model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"] - - for f in model_filenames: - logging.info(f) - if not Path(f).is_file(): - logging.info(f"Skipping {f} since {f} does not exist") - continue - model = load_vocoder(f) - wrapper = ModelWrapper(model) - wrapper.eval() - num_param = sum([p.numel() for p in wrapper.parameters()]) - logging.info(f"{f}: Number of parameters: {num_param}") - - # Use a large value so the rotary position embedding in the text - # encoder has a large initial length - x = torch.ones(1, 80, 100000, dtype=torch.float32) - opset_version = 14 - suffix = f.split("_")[-1] - filename = f"hifigan_{suffix}.onnx" - torch.onnx.export( - wrapper, - x, - filename, - opset_version=opset_version, - input_names=["mel"], - output_names=["audio"], - dynamic_axes={ - "mel": {0: "N", 2: "L"}, - "audio": {0: "N", 1: "L"}, - }, - ) - - meta_data = { - "model_type": "hifigan", - "model_filename": f.split("/")[-1], - "sample_rate": 22050, - "version": 1, - "model_author": "jik876", - "maintainer": "k2-fsa", - "dataset": "LJ Speech", - "url1": 
"https://github.com/jik876/hifi-gan", - "url2": "https://github.com/csukuangfj/models/tree/master/hifigan", - } - add_meta_data(filename=filename, meta_data=meta_data) - print(meta_data) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py deleted file mode 100644 index cc94a301f..000000000 --- a/egs/ljspeech/TTS/matcha/fbank.py +++ /dev/null @@ -1,89 +0,0 @@ -from dataclasses import dataclass -from typing import Union - -import numpy as np -import torch -from audio import mel_spectrogram -from lhotse.features.base import FeatureExtractor, register_extractor -from lhotse.utils import Seconds, compute_num_frames - - -@dataclass -class MatchaFbankConfig: - n_fft: int - n_mels: int - sampling_rate: int - hop_length: int - win_length: int - f_min: float - f_max: float - device: str = "cuda" - - -@register_extractor -class MatchaFbank(FeatureExtractor): - - name = "MatchaFbank" - config_type = MatchaFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: np.ndarray, - sampling_rate: int, - ) -> torch.Tensor: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - samples = torch.from_numpy(samples).to(self.device) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = ( - mel_spectrogram( - samples, - self.config.n_fft, - self.config.n_mels, - self.config.sampling_rate, - self.config.hop_length, - self.config.win_length, - self.config.f_min, - self.config.f_max, - center=False, - ) - .squeeze() - .t() - ) - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - return mel.cpu().numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate diff --git a/egs/ljspeech/TTS/matcha/hifigan/LICENSE b/egs/ljspeech/TTS/matcha/hifigan/LICENSE deleted file mode 100644 index 91751daed..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2020 Jungil Kong - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/egs/ljspeech/TTS/matcha/hifigan/README.md b/egs/ljspeech/TTS/matcha/hifigan/README.md deleted file mode 100644 index 5db258504..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/README.md +++ /dev/null @@ -1,101 +0,0 @@ -# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis - -### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae - -In our [paper](https://arxiv.org/abs/2010.05646), -we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
-We provide our implementation and pretrained models as open source in this repository. - -**Abstract :** -Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. -Although such methods improve the sampling efficiency and memory usage, -their sample quality has not yet reached that of autoregressive and flow-based generative models. -In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. -As speech audio consists of sinusoidal signals with various periods, -we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. -A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method -demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than -real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen -speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times -faster than real-time on CPU with comparable quality to an autoregressive counterpart. - -Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. - -## Pre-requisites - -1. Python >= 3.6 -2. Clone this repository. -3. Install python requirements. Please refer [requirements.txt](requirements.txt) -4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). - And move all wav files to `LJSpeech-1.1/wavs` - -## Training - -``` -python train.py --config config_v1.json -``` - -To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
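For orientation, the icefall copy of these three generator configurations lives in `hifigan/config.py` (removed further down in this diff) as the plain dictionaries `v1`, `v2`, and `v3`. A small sketch for comparing the fields in which they differ, assuming that module is importable from the recipe directory:

```python
# Sketch: compare the HiFi-GAN generator variants defined in hifigan/config.py.
# v1, v2 and v3 are plain dicts; only a few of their fields differ.
from hifigan.config import v1, v2, v3

for name, cfg in (("v1", v1), ("v2", v2), ("v3", v3)):
    print(
        f"{name}: resblock={cfg['resblock']}, "
        f"upsample_initial_channel={cfg['upsample_initial_channel']}, "
        f"upsample_rates={cfg['upsample_rates']}, "
        f"resblock_kernel_sizes={cfg['resblock_kernel_sizes']}"
    )
```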
-Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
-You can change the path by adding the `--checkpoint_path` option. - -Validation loss during training with the V1 generator is shown below.
-![validation loss](./validation_loss.png) - -## Pretrained Model - -You can also use the pretrained models we provide.
-[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
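In this recipe, such a downloaded generator checkpoint (`generator_v1`, `generator_v2`, or `generator_v3`) is loaded by the `load_vocoder` helper in `matcha/infer.py`, which is also removed in this diff. Stripped down, the loading and mel-to-audio path looks roughly like the sketch below; the random mel tensor is only a placeholder:

```python
# Rough sketch of how matcha/infer.py loads a pretrained HiFi-GAN generator
# and converts a mel-spectrogram to audio (mirrors load_vocoder/to_waveform).
# Assumes ./generator_v1 has been downloaded and hifigan/* is importable.
import torch

from hifigan.config import v1
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from icefall.utils import AttributeDict

hifigan = HiFiGAN(AttributeDict(v1)).to("cpu")
state_dict = torch.load("./generator_v1", map_location="cpu")
hifigan.load_state_dict(state_dict["generator"])
hifigan.eval()
hifigan.remove_weight_norm()

denoiser = Denoiser(hifigan, mode="zeros")

mel = torch.randn(1, 80, 100)      # placeholder mel: (N, num_mels, num_frames)
audio = hifigan(mel).clamp(-1, 1)  # (N, 1, num_samples)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
```

The denoiser step matches the recipe's `to_waveform`, which uses `strength=0.00025` to remove the slight vocoder bias from the generated audio.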
-Details of each folder are as in follows: - -| Folder Name | Generator | Dataset | Fine-Tuned | -| ------------ | --------- | --------- | ------------------------------------------------------ | -| LJ_V1 | V1 | LJSpeech | No | -| LJ_V2 | V2 | LJSpeech | No | -| LJ_V3 | V3 | LJSpeech | No | -| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | -| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | -| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | -| VCTK_V1 | V1 | VCTK | No | -| VCTK_V2 | V2 | VCTK | No | -| VCTK_V3 | V3 | VCTK | No | -| UNIVERSAL_V1 | V1 | Universal | No | - -We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. - -## Fine-Tuning - -1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
-   The file name of the generated mel-spectrogram should match that of the audio file, and the extension should be `.npy`.
-   Example: `Audio File: LJ001-0001.wav`, `Mel-Spectrogram File: LJ001-0001.npy` -2. Create an `ft_dataset` folder and copy the generated mel-spectrogram files into it.
-3. Run the following command. - ``` - python train.py --fine_tuning True --config config_v1.json - ``` - For other command line options, please refer to the training section. - -## Inference from wav file - -1. Make `test_files` directory and copy wav files into the directory. -2. Run the following command. - ` python inference.py --checkpoint_file [generator checkpoint file path]` - Generated wav files are saved in `generated_files` by default.
-   You can change the path by adding the `--output_dir` option. - -## Inference for end-to-end speech synthesis - -1. Make a `test_mel_files` directory and copy the generated mel-spectrogram files into the directory.
- You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), - [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. -2. Run the following command. - ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` - Generated wav files are saved in `generated_files_from_mel` by default.
- You can change the path by adding `--output_dir` option. - -## Acknowledgements - -We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) -and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/egs/ljspeech/TTS/matcha/hifigan/__init__.py b/egs/ljspeech/TTS/matcha/hifigan/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ljspeech/TTS/matcha/hifigan/config.py b/egs/ljspeech/TTS/matcha/hifigan/config.py deleted file mode 100644 index ecba62fd4..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/config.py +++ /dev/null @@ -1,100 +0,0 @@ -v1 = { - "resblock": "1", - "num_gpus": 0, - "batch_size": 16, - "learning_rate": 0.0004, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - "upsample_rates": [8, 8, 2, 2], - "upsample_kernel_sizes": [16, 16, 4, 4], - "upsample_initial_channel": 512, - "resblock_kernel_sizes": [3, 7, 11], - "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "resblock_initial_channel": 256, - "segment_size": 8192, - "num_mels": 80, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 256, - "win_size": 1024, - "sampling_rate": 22050, - "fmin": 0, - "fmax": 8000, - "fmax_loss": None, - "num_workers": 4, - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1, - }, -} - -# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf -v2 = { - "resblock": "1", - "num_gpus": 0, - "batch_size": 16, - "learning_rate": 0.0002, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - "upsample_rates": [8, 8, 2, 2], - "upsample_kernel_sizes": [16, 16, 4, 4], - "upsample_initial_channel": 128, - "resblock_kernel_sizes": [3, 7, 11], - "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "resblock_initial_channel": 64, - "segment_size": 8192, - "num_mels": 80, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 256, - "win_size": 1024, - "sampling_rate": 22050, - "fmin": 0, - "fmax": 8000, - "fmax_loss": None, - "num_workers": 4, - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1, - }, -} - -# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1 -v3 = { - "resblock": "2", - "num_gpus": 0, - "batch_size": 16, - "learning_rate": 0.0002, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - "upsample_rates": [8, 8, 4], - "upsample_kernel_sizes": [16, 16, 8], - "upsample_initial_channel": 256, - "resblock_kernel_sizes": [3, 5, 7], - "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]], - "resblock_initial_channel": 128, - "segment_size": 8192, - "num_mels": 80, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 256, - "win_size": 1024, - "sampling_rate": 22050, - "fmin": 0, - "fmax": 8000, - "fmax_loss": None, - "num_workers": 4, - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1, - }, -} diff --git a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py deleted file mode 100644 index b9aea61b8..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py +++ /dev/null @@ -1,71 +0,0 @@ -# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py - -"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" -import torch - - -class 
Denoiser(torch.nn.Module): - """Removes model bias from audio produced with waveglow""" - - def __init__( - self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros" - ): - super().__init__() - self.filter_length = filter_length - self.hop_length = int(filter_length / n_overlap) - self.win_length = win_length - - dtype, device = ( - next(vocoder.parameters()).dtype, - next(vocoder.parameters()).device, - ) - self.device = device - if mode == "zeros": - mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) - elif mode == "normal": - mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) - else: - raise Exception(f"Mode {mode} if not supported") - - def stft_fn(audio, n_fft, hop_length, win_length, window): - spec = torch.stft( - audio, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, - return_complex=True, - ) - spec = torch.view_as_real(spec) - return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2( - spec[..., -1], spec[..., 0] - ) - - self.stft = lambda x: stft_fn( - audio=x, - n_fft=self.filter_length, - hop_length=self.hop_length, - win_length=self.win_length, - window=torch.hann_window(self.win_length, device=device), - ) - self.istft = lambda x, y: torch.istft( - torch.complex(x * torch.cos(y), x * torch.sin(y)), - n_fft=self.filter_length, - hop_length=self.hop_length, - win_length=self.win_length, - window=torch.hann_window(self.win_length, device=device), - ) - - with torch.no_grad(): - bias_audio = vocoder(mel_input).float().squeeze(0) - bias_spec, _ = self.stft(bias_audio) - - self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) - - @torch.inference_mode() - def forward(self, audio, strength=0.0005): - audio_spec, audio_angles = self.stft(audio) - audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength - audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) - audio_denoised = self.istft(audio_spec_denoised, audio_angles) - return audio_denoised diff --git a/egs/ljspeech/TTS/matcha/hifigan/env.py b/egs/ljspeech/TTS/matcha/hifigan/env.py deleted file mode 100644 index 9ea4f948a..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/env.py +++ /dev/null @@ -1,17 +0,0 @@ -""" from https://github.com/jik876/hifi-gan """ - -import os -import shutil - - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.__dict__ = self - - -def build_env(config, config_name, path): - t_path = os.path.join(path, config_name) - if config != t_path: - os.makedirs(path, exist_ok=True) - shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py deleted file mode 100644 index 6eb15a326..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py +++ /dev/null @@ -1,245 +0,0 @@ -""" from https://github.com/jik876/hifi-gan """ - -import math -import os -import random - -import numpy as np -import torch -import torch.utils.data -from librosa.filters import mel as librosa_mel_fn -from librosa.util import normalize -from scipy.io.wavfile import read - -MAX_WAV_VALUE = 32768.0 - - -def load_wav(full_path): - sampling_rate, data = read(full_path) - return data, sampling_rate - - -def dynamic_range_compression(x, C=1, clip_val=1e-5): - return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) - - -def dynamic_range_decompression(x, C=1): - return np.exp(x) / C - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return 
torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def mel_spectrogram( - y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False -): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement - if fmax not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax) + "_" + str(y.device)] = ( - torch.from_numpy(mel).float().to(y.device) - ) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec - - -def get_dataset_filelist(a): - with open(a.input_training_file, encoding="utf-8") as fi: - training_files = [ - os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") - for x in fi.read().split("\n") - if len(x) > 0 - ] - - with open(a.input_validation_file, encoding="utf-8") as fi: - validation_files = [ - os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") - for x in fi.read().split("\n") - if len(x) > 0 - ] - return training_files, validation_files - - -class MelDataset(torch.utils.data.Dataset): - def __init__( - self, - training_files, - segment_size, - n_fft, - num_mels, - hop_size, - win_size, - sampling_rate, - fmin, - fmax, - split=True, - shuffle=True, - n_cache_reuse=1, - device=None, - fmax_loss=None, - fine_tuning=False, - base_mels_path=None, - ): - self.audio_files = training_files - random.seed(1234) - if shuffle: - random.shuffle(self.audio_files) - self.segment_size = segment_size - self.sampling_rate = sampling_rate - self.split = split - self.n_fft = n_fft - self.num_mels = num_mels - self.hop_size = hop_size - self.win_size = win_size - self.fmin = fmin - self.fmax = fmax - self.fmax_loss = fmax_loss - self.cached_wav = None - self.n_cache_reuse = n_cache_reuse - self._cache_ref_count = 0 - self.device = device - self.fine_tuning = fine_tuning - self.base_mels_path = base_mels_path - - def __getitem__(self, index): - filename = self.audio_files[index] - if self._cache_ref_count == 0: - audio, sampling_rate = load_wav(filename) - audio = audio / MAX_WAV_VALUE - if not self.fine_tuning: - audio = normalize(audio) * 0.95 - self.cached_wav = audio - if sampling_rate != self.sampling_rate: - raise ValueError( - f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" - ) - self._cache_ref_count = self.n_cache_reuse - else: - audio = self.cached_wav - self._cache_ref_count -= 1 - - audio = torch.FloatTensor(audio) - audio = audio.unsqueeze(0) - - if not self.fine_tuning: - if self.split: - if audio.size(1) >= 
self.segment_size: - max_audio_start = audio.size(1) - self.segment_size - audio_start = random.randint(0, max_audio_start) - audio = audio[:, audio_start : audio_start + self.segment_size] - else: - audio = torch.nn.functional.pad( - audio, (0, self.segment_size - audio.size(1)), "constant" - ) - - mel = mel_spectrogram( - audio, - self.n_fft, - self.num_mels, - self.sampling_rate, - self.hop_size, - self.win_size, - self.fmin, - self.fmax, - center=False, - ) - else: - mel = np.load( - os.path.join( - self.base_mels_path, - os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", - ) - ) - mel = torch.from_numpy(mel) - - if len(mel.shape) < 3: - mel = mel.unsqueeze(0) - - if self.split: - frames_per_seg = math.ceil(self.segment_size / self.hop_size) - - if audio.size(1) >= self.segment_size: - mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) - mel = mel[:, :, mel_start : mel_start + frames_per_seg] - audio = audio[ - :, - mel_start - * self.hop_size : (mel_start + frames_per_seg) - * self.hop_size, - ] - else: - mel = torch.nn.functional.pad( - mel, (0, frames_per_seg - mel.size(2)), "constant" - ) - audio = torch.nn.functional.pad( - audio, (0, self.segment_size - audio.size(1)), "constant" - ) - - mel_loss = mel_spectrogram( - audio, - self.n_fft, - self.num_mels, - self.sampling_rate, - self.hop_size, - self.win_size, - self.fmin, - self.fmax_loss, - center=False, - ) - - return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) - - def __len__(self): - return len(self.audio_files) diff --git a/egs/ljspeech/TTS/matcha/hifigan/models.py b/egs/ljspeech/TTS/matcha/hifigan/models.py deleted file mode 100644 index e6da20610..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/models.py +++ /dev/null @@ -1,406 +0,0 @@ -""" from https://github.com/jik876/hifi-gan """ - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d -from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm - -from .xutils import get_padding, init_weights - -LRELU_SLOPE = 0.1 - - -class ResBlock1(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super().__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - 
remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): - super().__init__() - self.h = h - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - self.convs.apply(init_weights) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super().__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm( - Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock1 if h.resblock == "1" else ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print("Removing weight norm...") - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super().__init__() - self.period = period - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f( - Conv2d( - 1, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ) - ), - norm_f( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ) - ), - norm_f( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ) - ), - norm_f( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(5, 1), 0), - ) - ), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, 
self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self): - super().__init__() - self.discriminators = nn.ModuleList( - [ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), - ] - ) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for _, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super().__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 128, 15, 1, padding=7)), - norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), - norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), - norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiScaleDiscriminator(torch.nn.Module): - def __init__(self): - super().__init__() - self.discriminators = nn.ModuleList( - [ - DiscriminatorS(use_spectral_norm=True), - DiscriminatorS(), - DiscriminatorS(), - ] - ) - self.meanpools = nn.ModuleList( - [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] - ) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - if i != 0: - y = self.meanpools[i - 1](y) - y_hat = self.meanpools[i - 1](y_hat) - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl - gl)) - - return loss * 2 - - -def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1 - dr) ** 2) - g_loss = torch.mean(dg**2) - loss += r_loss + g_loss - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - - return loss, r_losses, g_losses - - -def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - l = torch.mean((1 - dg) ** 2) - gen_losses.append(l) - loss += l - - return loss, gen_losses diff --git a/egs/ljspeech/TTS/matcha/hifigan/xutils.py b/egs/ljspeech/TTS/matcha/hifigan/xutils.py deleted file mode 100644 index eefadcb7a..000000000 --- a/egs/ljspeech/TTS/matcha/hifigan/xutils.py +++ /dev/null @@ -1,60 +0,0 @@ -""" from https://github.com/jik876/hifi-gan """ - -import glob -import os - -import matplotlib -import torch -from torch.nn.utils import weight_norm - -matplotlib.use("Agg") -import 
matplotlib.pylab as plt - - -def plot_spectrogram(spectrogram): - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - - fig.canvas.draw() - plt.close() - - return fig - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def apply_weight_norm(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - weight_norm(m) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - print(f"Loading '{filepath}'") - checkpoint_dict = torch.load(filepath, map_location=device) - print("Complete.") - return checkpoint_dict - - -def save_checkpoint(filepath, obj): - print(f"Saving checkpoint to {filepath}") - torch.save(obj, filepath) - print("Complete.") - - -def scan_checkpoint(cp_dir, prefix): - pattern = os.path.join(cp_dir, prefix + "????????") - cp_list = glob.glob(pattern) - if len(cp_list) == 0: - return None - return sorted(cp_list)[-1] diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py deleted file mode 100755 index 0b221d5c5..000000000 --- a/egs/ljspeech/TTS/matcha/infer.py +++ /dev/null @@ -1,328 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -import argparse -import datetime as dt -import json -import logging -from pathlib import Path - -import soundfile as sf -import torch -import torch.nn as nn -from hifigan.config import v1, v2, v3 -from hifigan.denoiser import Denoiser -from hifigan.models import Generator as HiFiGAN -from tokenizer import Tokenizer -from train import get_model, get_params -from tts_datamodule import LJSpeechTtsDataModule - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--vocoder", - type=Path, - default="./generator_v1", - help="Path to the vocoder", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - # The following arguments are used for inference on single text - parser.add_argument( - "--input-text", - type=str, - required=False, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=False, - help="The filename of the wave to save the generated speech", - ) - - parser.add_argument( - "--sampling-rate", - type=int, - default=22050, - help="The sampling rate of the generated speech (default: 22050 for LJSpeech)", - ) - - return parser - - -def load_vocoder(checkpoint_path: Path) -> nn.Module: - checkpoint_path = str(checkpoint_path) - if checkpoint_path.endswith("v1"): - h = AttributeDict(v1) - elif checkpoint_path.endswith("v2"): - h = AttributeDict(v2) - elif checkpoint_path.endswith("v3"): - h = AttributeDict(v3) - else: - raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") - - hifigan = HiFiGAN(h).to("cpu") - hifigan.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["generator"] - ) - _ = hifigan.eval() - hifigan.remove_weight_norm() - return hifigan - - -def to_waveform( - mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module -) -> torch.Tensor: - audio = vocoder(mel).clamp(-1, 1) - audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() - return audio.squeeze() - - -def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: - x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) - x = torch.tensor(x, dtype=torch.long, device=device) - x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) - return {"x_orig": text, "x": x, "x_lengths": x_lengths} - - -def synthesize( - model: nn.Module, - tokenizer: Tokenizer, - n_timesteps: int, - text: str, - length_scale: float, - temperature: float, - device: str = "cpu", - spks=None, -) -> dict: - text_processed = process_text(text=text, tokenizer=tokenizer, device=device) - start_t = dt.datetime.now() - output = model.synthesise( - text_processed["x"], - text_processed["x_lengths"], - n_timesteps=n_timesteps, - temperature=temperature, - spks=spks, - length_scale=length_scale, - ) - # merge everything to one dict - output.update({"start_t": start_t, **text_processed}) - return output - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - vocoder: nn.Module, - denoiser: nn.Module, - tokenizer: Tokenizer, -) -> None: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - tokenizer: - Used to convert text to phonemes. - """ - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
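# Note: len(dl) can raise TypeError when the underlying dataset is lazily
# evaluated and does not implement __len__; the "?" placeholder is only used
# for the progress logging further below.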
- - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["tokens"]) - - texts = [c.supervisions[0].normalized_text for c in batch["cut"]] - - audio = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - - for i in range(batch_size): - output = synthesize( - model=model, - tokenizer=tokenizer, - n_timesteps=params.n_timesteps, - text=texts[i], - length_scale=params.length_scale, - temperature=params.temperature, - device=device, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write( - file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", - data=output["waveform"], - samplerate=params.data_args.sampling_rate, - subtype="PCM_16", - ) - sf.write( - file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", - data=audio[i].numpy(), - samplerate=params.data_args.sampling_rate, - subtype="PCM_16", - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - -@torch.inference_mode() -def main(): - parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - - # Number of ODE Solver steps - params.n_timesteps = 2 - - # Changes to the speaking rate - params.length_scale = 1.0 - - # Sampling temperature - params.temperature = 0.667 - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - - # we need cut ids to organize tts results. 
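# Note: return_cuts=True makes the test dataloader attach the original lhotse
# Cut objects to each batch as batch["cut"]; infer_dataset above reads the cut
# ids from there to name the generated wav files.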
- args.return_cuts = True - ljspeech = LJSpeechTtsDataModule(args) - - test_cuts = ljspeech.test_cuts() - test_dl = ljspeech.test_dataloaders(test_cuts) - - if not Path(params.vocoder).is_file(): - raise ValueError(f"{params.vocoder} does not exist") - - vocoder = load_vocoder(params.vocoder) - vocoder.to(device) - - denoiser = Denoiser(vocoder, mode="zeros") - denoiser.to(device) - - if params.input_text is not None and params.output_wav is not None: - logging.info("Synthesizing a single text") - output = synthesize( - model=model, - tokenizer=tokenizer, - n_timesteps=params.n_timesteps, - text=params.input_text, - length_scale=params.length_scale, - temperature=params.temperature, - device=device, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write( - file=params.output_wav, - data=output["waveform"], - samplerate=params.sampling_rate, - subtype="PCM_16", - ) - else: - logging.info("Decoding the test set") - infer_dataset( - dl=test_dl, - params=params, - model=model, - vocoder=vocoder, - denoiser=denoiser, - tokenizer=tokenizer, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/matcha/model.py b/egs/ljspeech/TTS/matcha/model.py deleted file mode 100644 index 6539ffc24..000000000 --- a/egs/ljspeech/TTS/matcha/model.py +++ /dev/null @@ -1,97 +0,0 @@ -# This file is copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/model.py -""" from https://github.com/jaywalnut310/glow-tts """ - -import numpy as np -import torch - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def fix_len_compatibility(length, num_downsamplings_in_unet=2): - factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) - length = (length / factor).ceil() * factor - if not torch.onnx.is_in_onnx_export(): - return length.int().item() - else: - return length - - -def convert_pad_shape(pad_shape): - inverted_shape = pad_shape[::-1] - pad_shape = [item for sublist in inverted_shape for item in sublist] - return pad_shape - - -def generate_path(duration, mask): - device = duration.device - - b, t_x, t_y = mask.shape - cum_duration = torch.cumsum(duration, 1) - path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = ( - path - - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[ - :, :-1 - ] - ) - path = path * mask - return path - - -def duration_loss(logw, logw_, lengths): - loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) - return loss - - -def normalize(data, mu, std): - if not isinstance(mu, (float, int)): - if isinstance(mu, list): - mu = torch.tensor(mu, dtype=data.dtype, device=data.device) - elif isinstance(mu, torch.Tensor): - mu = mu.to(data.device) - elif isinstance(mu, np.ndarray): - mu = torch.from_numpy(mu).to(data.device) - mu = mu.unsqueeze(-1) - - if not isinstance(std, (float, int)): - if isinstance(std, list): - std = torch.tensor(std, dtype=data.dtype, device=data.device) - elif isinstance(std, torch.Tensor): - std = std.to(data.device) - elif isinstance(std, np.ndarray): - std = torch.from_numpy(std).to(data.device) - std = std.unsqueeze(-1) - - return (data - mu) / std - - -def denormalize(data, mu, std): - if not isinstance(mu, float): - if 
isinstance(mu, list): - mu = torch.tensor(mu, dtype=data.dtype, device=data.device) - elif isinstance(mu, torch.Tensor): - mu = mu.to(data.device) - elif isinstance(mu, np.ndarray): - mu = torch.from_numpy(mu).to(data.device) - mu = mu.unsqueeze(-1) - - if not isinstance(std, float): - if isinstance(std, list): - std = torch.tensor(std, dtype=data.dtype, device=data.device) - elif isinstance(std, torch.Tensor): - std = std.to(data.device) - elif isinstance(std, np.ndarray): - std = torch.from_numpy(std).to(data.device) - std = std.unsqueeze(-1) - - return data * std + mu diff --git a/egs/ljspeech/TTS/matcha/models/README.md b/egs/ljspeech/TTS/matcha/models/README.md deleted file mode 100644 index 1099ef3c8..000000000 --- a/egs/ljspeech/TTS/matcha/models/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Introduction -Files in this folder are copied from -https://github.com/shivammehta25/Matcha-TTS/tree/main/matcha/models diff --git a/egs/ljspeech/TTS/matcha/models/__init__.py b/egs/ljspeech/TTS/matcha/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ljspeech/TTS/matcha/models/components/__init__.py b/egs/ljspeech/TTS/matcha/models/components/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py deleted file mode 100644 index 102d87713..000000000 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ /dev/null @@ -1,459 +0,0 @@ -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from conformer import ConformerBlock -from diffusers.models.activations import get_activation -from einops import pack, rearrange, repeat -from models.components.transformer import BasicTransformerBlock - - -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" - - def forward(self, x, scale=1000): - if x.ndim < 1: - x = x.unsqueeze(0) - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -class Block1D(torch.nn.Module): - def __init__(self, dim, dim_out, groups=8): - super().__init__() - self.block = torch.nn.Sequential( - torch.nn.Conv1d(dim, dim_out, 3, padding=1), - torch.nn.GroupNorm(groups, dim_out), - nn.Mish(), - ) - - def forward(self, x, mask): - output = self.block(x * mask) - return output * mask - - -class ResnetBlock1D(torch.nn.Module): - def __init__(self, dim, dim_out, time_emb_dim, groups=8): - super().__init__() - self.mlp = torch.nn.Sequential( - nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out) - ) - - self.block1 = Block1D(dim, dim_out, groups=groups) - self.block2 = Block1D(dim_out, dim_out, groups=groups) - - self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) - - def forward(self, x, mask, time_emb): - h = self.block1(x, mask) - h += self.mlp(time_emb).unsqueeze(-1) - h = self.block2(h, mask) - output = h + self.res_conv(x * mask) - return output - - -class Downsample1D(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: 
int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - - self.act = get_activation(act_fn) - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - if post_act_fn is None: - self.post_act = None - else: - self.post_act = get_activation(post_act_fn) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class Upsample1D(nn.Module): - """A 1D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=True, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, inputs): - assert inputs.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(inputs) - - outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") - - if self.use_conv: - outputs = self.conv(outputs) - - return outputs - - -class ConformerWrapper(ConformerBlock): - def __init__( # pylint: disable=useless-super-delegation - self, - *, - dim, - dim_head=64, - heads=8, - ff_mult=4, - conv_expansion_factor=2, - conv_kernel_size=31, - attn_dropout=0, - ff_dropout=0, - conv_dropout=0, - conv_causal=False, - ): - super().__init__( - dim=dim, - dim_head=dim_head, - heads=heads, - ff_mult=ff_mult, - conv_expansion_factor=conv_expansion_factor, - conv_kernel_size=conv_kernel_size, - attn_dropout=attn_dropout, - ff_dropout=ff_dropout, - conv_dropout=conv_dropout, - conv_causal=conv_causal, - ) - - def forward( - self, - hidden_states, - attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - ): - return super().forward(x=hidden_states, mask=attention_mask.bool()) - - -class Decoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - channels=(256, 256), - dropout=0.05, - attention_head_dim=64, - n_blocks=1, - num_mid_blocks=2, - num_heads=4, - act_fn="snake", - down_block_type="transformer", - mid_block_type="transformer", - up_block_type="transformer", - ): - super().__init__() - channels = tuple(channels) - self.in_channels = in_channels - self.out_channels = out_channels - - self.time_embeddings = SinusoidalPosEmb(in_channels) - time_embed_dim = channels[0] * 4 - self.time_mlp = TimestepEmbedding( - 
in_channels=in_channels, - time_embed_dim=time_embed_dim, - act_fn="silu", - ) - - self.down_blocks = nn.ModuleList([]) - self.mid_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - output_channel = in_channels - for i in range(len(channels)): # pylint: disable=consider-using-enumerate - input_channel = output_channel - output_channel = channels[i] - is_last = i == len(channels) - 1 - resnet = ResnetBlock1D( - dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim - ) - transformer_blocks = nn.ModuleList( - [ - self.get_block( - down_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - downsample = ( - Downsample1D(output_channel) - if not is_last - else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - - self.down_blocks.append( - nn.ModuleList([resnet, transformer_blocks, downsample]) - ) - - for i in range(num_mid_blocks): - input_channel = channels[-1] - out_channels = channels[-1] - - resnet = ResnetBlock1D( - dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim - ) - - transformer_blocks = nn.ModuleList( - [ - self.get_block( - mid_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - - self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) - - channels = channels[::-1] + (channels[0],) - for i in range(len(channels) - 1): - input_channel = channels[i] - output_channel = channels[i + 1] - is_last = i == len(channels) - 2 - - resnet = ResnetBlock1D( - dim=2 * input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) - transformer_blocks = nn.ModuleList( - [ - self.get_block( - up_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - upsample = ( - Upsample1D(output_channel, use_conv_transpose=True) - if not is_last - else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - - self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) - - self.final_block = Block1D(channels[-1], channels[-1]) - self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) - - self.initialize_weights() - # nn.init.normal_(self.final_proj.weight) - - @staticmethod - def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): - if block_type == "conformer": - block = ConformerWrapper( - dim=dim, - dim_head=attention_head_dim, - heads=num_heads, - ff_mult=1, - conv_expansion_factor=2, - ff_dropout=dropout, - attn_dropout=dropout, - conv_dropout=dropout, - conv_kernel_size=31, - ) - elif block_type == "transformer": - block = BasicTransformerBlock( - dim=dim, - num_attention_heads=num_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - activation_fn=act_fn, - ) - else: - raise ValueError(f"Unknown block type {block_type}") - - return block - - def initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - elif isinstance(m, nn.GroupNorm): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x, mask, mu, t, spks=None, cond=None): - """Forward pass of the UNet1DConditional model. 
- - Args: - x (torch.Tensor): shape (batch_size, in_channels, time) - mask (_type_): shape (batch_size, 1, time) - t (_type_): shape (batch_size) - spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. - cond (_type_, optional): placeholder for future use. Defaults to None. - - Raises: - ValueError: _description_ - ValueError: _description_ - - Returns: - _type_: _description_ - """ - - t = self.time_embeddings(t) - t = self.time_mlp(t) - - x = pack([x, mu], "b * t")[0] - - if spks is not None: - spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) - x = pack([x, spks], "b * t")[0] - - hiddens = [] - masks = [mask] - for resnet, transformer_blocks, downsample in self.down_blocks: - mask_down = masks[-1] - x = resnet(x, mask_down, t) - x = rearrange(x, "b c t -> b t c") - mask_down = rearrange(mask_down, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_down, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_down = rearrange(mask_down, "b t -> b 1 t") - hiddens.append(x) # Save hidden states for skip connections - x = downsample(x * mask_down) - masks.append(mask_down[:, :, ::2]) - - masks = masks[:-1] - mask_mid = masks[-1] - - for resnet, transformer_blocks in self.mid_blocks: - x = resnet(x, mask_mid, t) - x = rearrange(x, "b c t -> b t c") - mask_mid = rearrange(mask_mid, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_mid, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_mid = rearrange(mask_mid, "b t -> b 1 t") - - for resnet, transformer_blocks, upsample in self.up_blocks: - mask_up = masks.pop() - x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) - x = rearrange(x, "b c t -> b t c") - mask_up = rearrange(mask_up, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_up, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_up = rearrange(mask_up, "b t -> b 1 t") - x = upsample(x * mask_up) - - x = self.final_block(x, mask_up) - output = self.final_proj(x * mask_up) - - return output * mask diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py deleted file mode 100644 index eb795ef32..000000000 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ /dev/null @@ -1,140 +0,0 @@ -from abc import ABC - -import torch -import torch.nn.functional as F -from models.components.decoder import Decoder - - -class BASECFM(torch.nn.Module, ABC): - def __init__( - self, - n_feats, - cfm_params, - n_spks=1, - spk_emb_dim=128, - ): - super().__init__() - self.n_feats = n_feats - self.n_spks = n_spks - self.spk_emb_dim = spk_emb_dim - self.solver = cfm_params.solver - if hasattr(cfm_params, "sigma_min"): - self.sigma_min = cfm_params.sigma_min - else: - self.sigma_min = 1e-4 - - self.estimator = None - - @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): - """Forward diffusion - - Args: - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - n_timesteps (int): number of diffusion steps - temperature (float, optional): temperature for scaling noise. Defaults to 1.0. - spks (torch.Tensor, optional): speaker ids. Defaults to None. 
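Decoder.forward above follows the usual 1-D U-Net pattern: the mask is subsampled by 2 at every downsampling stage, and the pre-downsample hidden states are saved so that each upsampling stage can concatenate them back along the channel axis with einops.pack (hence ResnetBlock1D(dim=2 * input_channel) in the up path). A small standalone sketch of that bookkeeping, with made-up shapes:

import torch
from einops import pack

x = torch.randn(1, 256, 100)           # (batch, channels, time)
mask = torch.ones(1, 1, 100)

hidden = x                             # saved before downsampling (skip connection)
mask_half = mask[:, :, ::2]            # stride-2 subsampling keeps the mask aligned
print(mask_half.shape)                 # torch.Size([1, 1, 50])

up = torch.randn(1, 256, 100)          # features after the matching upsample
packed, _ = pack([up, hidden], "b * t")
print(packed.shape)                    # torch.Size([1, 512, 100]) -> fed to ResnetBlock1D(dim=2 * input_channel)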
- shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - - Returns: - sample: generated mel-spectrogram - shape: (batch_size, n_feats, mel_timesteps) - """ - z = torch.randn_like(mu) * temperature - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) - return self.solve_euler( - z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond - ) - - def solve_euler(self, x, t_span, mu, mask, spks, cond): - """ - Fixed euler solver for ODEs. - Args: - x (torch.Tensor): random noise - t_span (torch.Tensor): n_timesteps interpolated - shape: (n_timesteps + 1,) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - - # I am storing this because I can later plot it by putting a debugger here and saving it to a file - # Or in future might add like a return_all_steps flag - sol = [] - - for step in range(1, len(t_span)): - dphi_dt = self.estimator(x, mask, mu, t, spks, cond) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def compute_loss(self, x1, mask, mu, spks=None, cond=None): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): target mask - shape: (batch_size, 1, mel_timesteps) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - spks (torch.Tensor, optional): speaker embedding. Defaults to None. - shape: (batch_size, spk_emb_dim) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_feats, mel_timesteps) - """ - b, _, t = mu.shape - - # random timestep - t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - - loss = F.mse_loss( - self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum" - ) / (torch.sum(mask) * u.shape[1]) - return loss, y - - -class CFM(BASECFM): - def __init__( - self, - in_channels, - out_channel, - cfm_params, - decoder_params, - n_spks=1, - spk_emb_dim=64, - ): - super().__init__( - n_feats=in_channels, - cfm_params=cfm_params, - n_spks=n_spks, - spk_emb_dim=spk_emb_dim, - ) - - in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) - # Just change the architecture of the estimator here - self.estimator = Decoder( - in_channels=in_channels, out_channels=out_channel, **decoder_params - ) diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py deleted file mode 100644 index 364ff1938..000000000 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ /dev/null @@ -1,447 +0,0 @@ -""" from https://github.com/jaywalnut310/glow-tts """ - -import math - -import torch -import torch.nn as nn -from einops import rearrange -from model import sequence_mask - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-4): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = torch.nn.Parameter(torch.ones(channels)) - self.beta = torch.nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - n_dims = 
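BASECFM above contains the two essential pieces of conditional flow matching: the training target y_t = (1 - (1 - sigma_min) t) z + t x1 with vector field u = x1 - (1 - sigma_min) z, and a fixed-step Euler integration of the learned field at inference. A toy, self-contained sketch of both, where `estimator` is a stand-in callable rather than the real Decoder:

import torch

sigma_min = 1e-4

def cfm_target(x1, z, t):
    """Interpolant y_t and target vector field u for OT-CFM."""
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z
    return y, u

def euler_solve(estimator, z, n_timesteps):
    """Integrate dx/dt = estimator(x, t) from t=0 to t=1 with fixed steps."""
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    x, t = z, t_span[0]
    for step in range(1, len(t_span)):
        dt = t_span[step] - t
        x = x + dt * estimator(x, t)
        t = t_span[step]
    return x

# Usage with a dummy estimator that ignores t:
x = euler_solve(lambda x, t: -x, torch.randn(2, 80, 10), n_timesteps=4)
print(x.shape)  # torch.Size([2, 80, 10])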
len(x.shape) - mean = torch.mean(x, 1, keepdim=True) - variance = torch.mean((x - mean) ** 2, 1, keepdim=True) - - x = (x - mean) * torch.rsqrt(variance + self.eps) - - shape = [1, -1] + [1] * (n_dims - 2) - x = x * self.gamma.view(*shape) + self.beta.view(*shape) - return x - - -class ConvReluNorm(nn.Module): - def __init__( - self, - in_channels, - hidden_channels, - out_channels, - kernel_size, - n_layers, - p_dropout, - ): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.conv_layers = torch.nn.ModuleList() - self.norm_layers = torch.nn.ModuleList() - self.conv_layers.append( - torch.nn.Conv1d( - in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 - ) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = torch.nn.Sequential( - torch.nn.ReLU(), torch.nn.Dropout(p_dropout) - ) - for _ in range(n_layers - 1): - self.conv_layers.append( - torch.nn.Conv1d( - hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2, - ) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask - - -class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.p_dropout = p_dropout - - self.drop = torch.nn.Dropout(p_dropout) - self.conv_1 = torch.nn.Conv1d( - in_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_1 = LayerNorm(filter_channels) - self.conv_2 = torch.nn.Conv1d( - filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_2 = LayerNorm(filter_channels) - self.proj = torch.nn.Conv1d(filter_channels, 1, 1) - - def forward(self, x, x_mask): - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class RotaryPositionalEmbeddings(nn.Module): - """ - ## RoPE module - - Rotary encoding transforms pairs of features by rotating in the 2D plane. - That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. - Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it - by an angle depending on the position of the token. 
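The LayerNorm above normalises over the channel axis (dim=1) of a (B, C, T) tensor, which is equivalent to applying torch.nn.LayerNorm to the channel dimension after a transpose. A quick standalone check of that equivalence (shapes are arbitrary):

import torch
import torch.nn as nn

x = torch.randn(2, 16, 9)                               # (B, C, T)
eps = 1e-4

mean = x.mean(1, keepdim=True)
var = ((x - mean) ** 2).mean(1, keepdim=True)
manual = (x - mean) * torch.rsqrt(var + eps)

ref = nn.LayerNorm(16, eps=eps, elementwise_affine=False)
print(torch.allclose(manual, ref(x.transpose(1, 2)).transpose(1, 2), atol=1e-5))  # True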
- """ - - def __init__(self, d: int, base: int = 10_000): - r""" - * `d` is the number of features $d$ - * `base` is the constant used for calculating $\Theta$ - """ - super().__init__() - - self.base = base - self.d = int(d) - self.cos_cached = None - self.sin_cached = None - - def _build_cache(self, x: torch.Tensor): - r""" - Cache $\cos$ and $\sin$ values - """ - # Return if cache is already built - if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: - return - - # Get sequence length - seq_len = x.shape[0] - - # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to( - x.device - ) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.einsum("n,d->nd", seq_idx, theta) - - # Concatenate so that for row $m$ we have - # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ - idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) - - # Cache them - self.cos_cached = idx_theta2.cos()[:, None, None, :] - self.sin_cached = idx_theta2.sin()[:, None, None, :] - - def _neg_half(self, x: torch.Tensor): - # $\frac{d}{2}$ - d_2 = self.d // 2 - - # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) - - def forward(self, x: torch.Tensor): - """ - * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` - """ - # Cache $\cos$ and $\sin$ values - x = rearrange(x, "b h t d -> t b h d") - - self._build_cache(x) - - # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
- x_rope, x_pass = x[..., : self.d], x[..., self.d :] - - # Calculate - # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - neg_half_x = self._neg_half(x_rope) - - x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + ( - neg_half_x * self.sin_cached[: x.shape[0]] - ) - - return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels, - out_channels, - n_heads, - heads_share=True, - p_dropout=0.0, - proximal_bias=False, - proximal_init=False, - ): - super().__init__() - assert channels % n_heads == 0 - - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.heads_share = heads_share - self.proximal_bias = proximal_bias - self.p_dropout = p_dropout - self.attn = None - - self.k_channels = channels // n_heads - self.conv_q = torch.nn.Conv1d(channels, channels, 1) - self.conv_k = torch.nn.Conv1d(channels, channels, 1) - self.conv_v = torch.nn.Conv1d(channels, channels, 1) - - # from https://nn.labml.ai/transformers/rope/index.html - self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) - self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) - - self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) - self.drop = torch.nn.Dropout(p_dropout) - - torch.nn.init.xavier_uniform_(self.conv_q.weight) - torch.nn.init.xavier_uniform_(self.conv_k.weight) - if proximal_init: - self.conv_k.weight.data.copy_(self.conv_q.weight.data) - self.conv_k.bias.data.copy_(self.conv_q.bias.data) - torch.nn.init.xavier_uniform_(self.conv_v.weight) - - def forward(self, x, c, attn_mask=None): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, self.attn = self.attention(q, k, v, mask=attn_mask) - - x = self.conv_o(x) - return x - - def attention(self, query, key, value, mask=None): - b, d, t_s, t_t = (*key.size(), query.size(2)) - query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) - key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) - value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) - - query = self.query_rotary_pe(query) - key = self.key_rotary_pe(key) - - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) - - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." 
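RotaryPositionalEmbeddings above implements the standard rotate-half form of RoPE. A compact standalone sketch of the same computation; for simplicity it rotates all features, whereas the deleted module is constructed with d = k_channels * 0.5 and therefore rotates only the first half of each head and passes the rest through:

import torch

def rope(x, base=10_000):
    # x: (seq_len, batch, heads, d) with d even
    seq_len, _, _, d = x.shape
    theta = 1.0 / base ** (torch.arange(0, d, 2).float() / d)            # (d/2,)
    idx_theta = torch.einsum("n,d->nd", torch.arange(seq_len).float(), theta)
    angles = torch.cat([idx_theta, idx_theta], dim=1)[:, None, None, :]  # (seq, 1, 1, d)
    neg_half = torch.cat([-x[..., d // 2:], x[..., : d // 2]], dim=-1)
    return x * angles.cos() + neg_half * angles.sin()

q = torch.randn(16, 2, 4, 32)        # (t, b, h, d)
print(rope(q).shape)                 # torch.Size([16, 2, 4, 32])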
- scores = scores + self._attention_bias_proximal(t_s).to( - device=scores.device, dtype=scores.dtype - ) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - p_attn = torch.nn.functional.softmax(scores, dim=-1) - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - output = output.transpose(2, 3).contiguous().view(b, d, t_t) - return output, p_attn - - @staticmethod - def _attention_bias_proximal(length): - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) - - -class FFN(nn.Module): - def __init__( - self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0 - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.conv_1 = torch.nn.Conv1d( - in_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.conv_2 = torch.nn.Conv1d( - filter_channels, out_channels, kernel_size, padding=kernel_size // 2 - ) - self.drop = torch.nn.Dropout(p_dropout) - - def forward(self, x, x_mask): - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - return x * x_mask - - -class Encoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - **kwargs, - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.drop = torch.nn.Dropout(p_dropout) - self.attn_layers = torch.nn.ModuleList() - self.norm_layers_1 = torch.nn.ModuleList() - self.ffn_layers = torch.nn.ModuleList() - self.norm_layers_2 = torch.nn.ModuleList() - for _ in range(self.n_layers): - self.attn_layers.append( - MultiHeadAttention( - hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - for i in range(self.n_layers): - x = x * x_mask - y = self.attn_layers[i](x, x, attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class TextEncoder(nn.Module): - def __init__( - self, - encoder_type, - encoder_params, - duration_predictor_params, - n_vocab, - n_spks=1, - spk_emb_dim=128, - ): - super().__init__() - self.encoder_type = encoder_type - self.n_vocab = n_vocab - self.n_feats = encoder_params.n_feats - self.n_channels = encoder_params.n_channels - self.spk_emb_dim = spk_emb_dim - self.n_spks = n_spks - - self.emb = torch.nn.Embedding(n_vocab, self.n_channels) - torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) - - if encoder_params.prenet: - self.prenet = ConvReluNorm( - self.n_channels, - self.n_channels, - self.n_channels, - kernel_size=5, - n_layers=3, - p_dropout=0.5, - ) - else: - self.prenet = lambda x, x_mask: x - - self.encoder = Encoder( - encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), - 
encoder_params.filter_channels, - encoder_params.n_heads, - encoder_params.n_layers, - encoder_params.kernel_size, - encoder_params.p_dropout, - ) - - self.proj_m = torch.nn.Conv1d( - self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1 - ) - self.proj_w = DurationPredictor( - self.n_channels + (spk_emb_dim if n_spks > 1 else 0), - duration_predictor_params.filter_channels_dp, - duration_predictor_params.kernel_size, - duration_predictor_params.p_dropout, - ) - - def forward(self, x, x_lengths, spks=None): - """Run forward pass to the transformer based encoder and duration predictor - - Args: - x (torch.Tensor): text input - shape: (batch_size, max_text_length) - x_lengths (torch.Tensor): text input lengths - shape: (batch_size,) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size,) - - Returns: - mu (torch.Tensor): average output of the encoder - shape: (batch_size, n_feats, max_text_length) - logw (torch.Tensor): log duration predicted by the duration predictor - shape: (batch_size, 1, max_text_length) - x_mask (torch.Tensor): mask for the text input - shape: (batch_size, 1, max_text_length) - """ - x = self.emb(x) * math.sqrt(self.n_channels) - x = torch.transpose(x, 1, -1) - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - - x = self.prenet(x, x_mask) - if self.n_spks > 1: - x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) - x = self.encoder(x, x_mask) - mu = self.proj_m(x) * x_mask - - x_dp = torch.detach(x) - logw = self.proj_w(x_dp, x_mask) - - return mu, logw, x_mask diff --git a/egs/ljspeech/TTS/matcha/models/components/transformer.py b/egs/ljspeech/TTS/matcha/models/components/transformer.py deleted file mode 100644 index a82e560bc..000000000 --- a/egs/ljspeech/TTS/matcha/models/components/transformer.py +++ /dev/null @@ -1,353 +0,0 @@ -from typing import Any, Dict, Optional - -import torch -import torch.nn as nn -from diffusers.models.attention import ( - GEGLU, - GELU, - AdaLayerNorm, - AdaLayerNormZero, - ApproximateGELU, -) -from diffusers.models.attention_processor import Attention -from diffusers.models.lora import LoRACompatibleLinear -from diffusers.utils.torch_utils import maybe_allow_in_graph - - -class SnakeBeta(nn.Module): - """ - A modified Snake function which uses separate parameters for the magnitude of the periodic components - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - References: - - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://arxiv.org/abs/2006.08195 - Examples: - >>> a1 = snakebeta(256) - >>> x = torch.randn(256) - >>> x = a1(x) - """ - - def __init__( - self, - in_features, - out_features, - alpha=1.0, - alpha_trainable=True, - alpha_logscale=True, - ): - """ - Initialization. - INPUT: - - in_features: shape of the input - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - alpha is initialized to 1 by default, higher values = higher-frequency. - beta is initialized to 1 by default, higher values = higher-magnitude. - alpha will be trained along with the rest of your model. 
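Note the stop-gradient at the end of TextEncoder.forward above: logw is predicted from torch.detach(x), so the duration loss does not back-propagate into the encoder. A tiny standalone illustration with a Conv1d standing in for the duration predictor:

import torch
import torch.nn as nn

enc_out = torch.randn(2, 192, 17, requires_grad=True)   # (B, n_channels, T_text)
proj_w = nn.Conv1d(192, 1, 3, padding=1)                 # stand-in duration predictor

logw = proj_w(enc_out.detach())                          # detached -> no grad flows to the encoder
logw.sum().backward()
print(enc_out.grad)                                      # None: the duration loss leaves the encoder untouched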
- """ - super().__init__() - self.in_features = ( - out_features if isinstance(out_features, list) else [out_features] - ) - self.proj = LoRACompatibleLinear(in_features, out_features) - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) - self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) - self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - self.beta.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - """ - Forward pass of the function. - Applies the function to the input elementwise. - SnakeBeta ∶= x + 1/b * sin^2 (xa) - """ - x = self.proj(x) - if self.alpha_logscale: - alpha = torch.exp(self.alpha) - beta = torch.exp(self.beta) - else: - alpha = self.alpha - beta = self.beta - - x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( - torch.sin(x * alpha), 2 - ) - - return x - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - elif activation_fn == "snakebeta": - act_fn = SnakeBeta(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. 
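SnakeBeta above computes x + (1/beta) * sin^2(alpha * x), with alpha and beta stored in log scale so that zero-initialised parameters give alpha = beta = 1. A minimal functional sketch of that activation, detached from the LoRACompatibleLinear projection in the deleted class:

import torch

def snake_beta(x, log_alpha, log_beta, eps=1e-9):
    alpha, beta = torch.exp(log_alpha), torch.exp(log_beta)
    return x + (1.0 / (beta + eps)) * torch.sin(x * alpha) ** 2

x = torch.linspace(-3, 3, 7)
print(snake_beta(x, torch.zeros(1), torch.zeros(1)))  # plain Snake with alpha = beta = 1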
In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim - if not double_self_attention - else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - # scale_qk=False, # uncomment this to not to use flash attention - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. 
Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - ): - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states - if self.only_cross_attention - else None, - attention_mask=encoder_attention_mask - if self.only_cross_attention - else attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) - if self.use_ada_layer_norm - else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
- ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [ - self.ff(hid_slice) - for hid_slice in norm_hidden_states.chunk( - num_chunks, dim=self._chunk_dim - ) - ], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py deleted file mode 100644 index fe0a72402..000000000 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ /dev/null @@ -1,295 +0,0 @@ -import datetime as dt -import math -import random - -import monotonic_align as monotonic_align -import torch -from model import ( - denormalize, - duration_loss, - fix_len_compatibility, - generate_path, - sequence_mask, -) -from models.components.flow_matching import CFM -from models.components.text_encoder import TextEncoder - - -class MatchaTTS(torch.nn.Module): # 🍵 - def __init__( - self, - n_vocab, - n_spks, - spk_emb_dim, - n_feats, - encoder, - decoder, - cfm, - data_statistics, - out_size, - optimizer=None, - scheduler=None, - prior_loss=True, - use_precomputed_durations=False, - ): - super().__init__() - - # self.save_hyperparameters(logger=False) - - self.n_vocab = n_vocab - self.n_spks = n_spks - self.spk_emb_dim = spk_emb_dim - self.n_feats = n_feats - self.out_size = out_size - self.prior_loss = prior_loss - self.use_precomputed_durations = use_precomputed_durations - - if n_spks > 1: - self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) - - self.encoder = TextEncoder( - encoder.encoder_type, - encoder.encoder_params, - encoder.duration_predictor_params, - n_vocab, - n_spks, - spk_emb_dim, - ) - - self.decoder = CFM( - in_channels=2 * encoder.encoder_params.n_feats, - out_channel=encoder.encoder_params.n_feats, - cfm_params=cfm, - decoder_params=decoder, - n_spks=n_spks, - spk_emb_dim=spk_emb_dim, - ) - - if data_statistics is not None: - self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) - self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) - else: - self.register_buffer("mel_mean", torch.tensor(0.0)) - self.register_buffer("mel_std", torch.tensor(1.0)) - - @torch.inference_mode() - def synthesise( - self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0 - ): - """ - Generates mel-spectrogram from text. Returns: - 1. encoder outputs - 2. decoder outputs - 3. generated alignment - - Args: - x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. - shape: (batch_size, max_text_length) - x_lengths (torch.Tensor): lengths of texts in batch. - shape: (batch_size,) - n_timesteps (int): number of steps to use for reverse diffusion in decoder. - temperature (float, optional): controls variance of terminal distribution. - spks (bool, optional): speaker ids. - shape: (batch_size,) - length_scale (float, optional): controls speech pace. - Increase value to slow down generated speech and vice versa. 
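The chunked feed-forward path above trades peak activation memory for extra kernel launches; splitting along the sequence dimension gives the same output as a single pass because the feed-forward acts on the last dimension only. A standalone sketch with a placeholder MLP:

import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
h = torch.randn(2, 100, 64)                 # (batch, seq, dim), seq divisible by the chunk size

chunk_size, chunk_dim = 25, 1
num_chunks = h.shape[chunk_dim] // chunk_size
chunked = torch.cat(
    [ff(part) for part in h.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim
)

print(torch.allclose(chunked, ff(h), atol=1e-6))  # True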
- - Returns: - dict: { - "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), - # Average mel spectrogram generated by the encoder - "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), - # Refined mel spectrogram improved by the CFM - "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), - # Alignment map between text and mel spectrogram - "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), - # Denormalized mel spectrogram - "mel_lengths": torch.Tensor, shape: (batch_size,), - # Lengths of mel spectrograms - "rtf": float, - # Real-time factor - """ - # For RTF computation - t = dt.datetime.now() - - if self.n_spks > 1: - # Get speaker embedding - spks = self.spk_emb(spks.long()) - - # Get encoder_outputs `mu_x` and log-scaled token durations `logw` - mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) - - w = torch.exp(logw) * x_mask - w_ceil = torch.ceil(w) * length_scale - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_max_length = y_lengths.max() - y_max_length_ = fix_len_compatibility(y_max_length) - - # Using obtained durations `w` construct alignment map `attn` - y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) - attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) - attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) - - # Align encoded text and get mu_y - mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) - mu_y = mu_y.transpose(1, 2) - encoder_outputs = mu_y[:, :, :y_max_length] - - # Generate sample tracing the probability flow - decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) - decoder_outputs = decoder_outputs[:, :, :y_max_length] - - t = (dt.datetime.now() - t).total_seconds() - rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) - - return { - "encoder_outputs": encoder_outputs, - "decoder_outputs": decoder_outputs, - "attn": attn[:, :, :y_max_length], - "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), - "mel_lengths": y_lengths, - "rtf": rtf, - } - - def forward( - self, - x, - x_lengths, - y, - y_lengths, - spks=None, - out_size=None, - cond=None, - durations=None, - ): - """ - Computes 3 losses: - 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). - 2. prior loss: loss between mel-spectrogram and encoder outputs. - 3. flow matching loss: loss between mel-spectrogram and decoder outputs. - - Args: - x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. - shape: (batch_size, max_text_length) - x_lengths (torch.Tensor): lengths of texts in batch. - shape: (batch_size,) - y (torch.Tensor): batch of corresponding mel-spectrograms. - shape: (batch_size, n_feats, max_mel_length) - y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. - shape: (batch_size,) - out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. - Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. - spks (torch.Tensor, optional): speaker ids. 
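In synthesise() above, the predicted log-durations are exponentiated, ceiled and scaled by length_scale, and their sum fixes the number of mel frames to generate; the reported RTF then divides wall-clock time by the generated audio duration (hop length 256 at 22050 Hz). A toy sketch with made-up numbers:

import torch

logw = torch.tensor([[[0.0, 1.0, 2.0]]])        # (B, 1, T_text) log-durations
x_mask = torch.ones(1, 1, 3)
length_scale = 1.0                              # values > 1.0 slow the speech down

w_ceil = torch.ceil(torch.exp(logw) * x_mask) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
print(w_ceil, y_lengths)                        # per-token frame counts and total mel length

elapsed, hop, sr = 0.05, 256, 22050             # hypothetical timing
rtf = elapsed * sr / (y_lengths.item() * hop)
print(f"RTF = {rtf:.3f}")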
- shape: (batch_size,) - """ - if self.n_spks > 1: - # Get speaker embedding - spks = self.spk_emb(spks) - - # Get encoder_outputs `mu_x` and log-scaled token durations `logw` - mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) - y_max_length = y.shape[-1] - - y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) - attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) - - if self.use_precomputed_durations: - attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) - else: - # Use MAS to find most likely alignment `attn` between text and mel-spectrogram - with torch.no_grad(): - const = -0.5 * math.log(2 * math.pi) * self.n_feats - factor = -0.5 * torch.ones( - mu_x.shape, dtype=mu_x.dtype, device=mu_x.device - ) - y_square = torch.matmul(factor.transpose(1, 2), y**2) - y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) - mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) - log_prior = y_square - y_mu_double + mu_square + const - - attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) - attn = attn.detach() # b, t_text, T_mel - - # Compute loss between predicted log-scaled durations and those obtained from MAS - # refered to as prior loss in the paper - logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask - dur_loss = duration_loss(logw, logw_, x_lengths) - - # Cut a small segment of mel-spectrogram in order to increase batch size - # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it - # - Do not need this hack for Matcha-TTS, but it works with it as well - if not isinstance(out_size, type(None)): - max_offset = (y_lengths - out_size).clamp(0) - offset_ranges = list( - zip([0] * max_offset.shape[0], max_offset.cpu().numpy()) - ) - out_offset = torch.LongTensor( - [ - torch.tensor(random.choice(range(start, end)) if end > start else 0) - for start, end in offset_ranges - ] - ).to(y_lengths) - attn_cut = torch.zeros( - attn.shape[0], - attn.shape[1], - out_size, - dtype=attn.dtype, - device=attn.device, - ) - y_cut = torch.zeros( - y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device - ) - - y_cut_lengths = [] - for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): - y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) - y_cut_lengths.append(y_cut_length) - cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length - y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] - attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] - - y_cut_lengths = torch.LongTensor(y_cut_lengths) - y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) - - attn = attn_cut - y = y_cut - y_mask = y_cut_mask - - # Align encoded text with mel-spectrogram and get mu_y segment - mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) - mu_y = mu_y.transpose(1, 2) - - # Compute loss of the decoder - diff_loss, _ = self.decoder.compute_loss( - x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond - ) - - if self.prior_loss: - prior_loss = torch.sum( - 0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask - ) - prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) - else: - prior_loss = 0 - - return dur_loss, prior_loss, diff_loss, attn - - def get_losses(self, batch): - x, x_lengths = batch["x"], batch["x_lengths"] - y, y_lengths = batch["y"], batch["y_lengths"] - spks = batch["spks"] - - dur_loss, prior_loss, diff_loss, *_ = self( - x=x, - x_lengths=x_lengths, - y=y, - y_lengths=y_lengths, - 
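The MAS branch above evaluates log N(y; mu_x, I) for every (token, frame) pair without materialising a 4-D tensor, by expanding -0.5 * ||y - mu||^2 into three matrix products. A standalone check that the factored form equals the direct computation:

import math
import torch

n_feats, t_text, t_mel = 80, 5, 7
mu_x = torch.randn(1, n_feats, t_text)
y = torch.randn(1, n_feats, t_mel)

const = -0.5 * math.log(2 * math.pi) * n_feats
factor = -0.5 * torch.ones_like(mu_x)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const            # (1, t_text, t_mel)

direct = -0.5 * ((y.unsqueeze(2) - mu_x.unsqueeze(3)) ** 2).sum(1) + const
print(torch.allclose(log_prior, direct, atol=1e-5))                # True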
spks=spks, - out_size=self.out_size, - durations=batch["durations"], - ) - return { - "dur_loss": dur_loss, - "prior_loss": prior_loss, - "diff_loss": diff_loss, - } diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore deleted file mode 100644 index 3def4ae26..000000000 --- a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -build -core.c -*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py deleted file mode 100644 index f87ae1bd3..000000000 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -import torch - -from .core import maximum_path_c - - -def maximum_path(value, mask): - """Cython optimised version. - value: [b, t_x, t_y] - mask: [b, t_x, t_y] - """ - value = value * mask - device = value.device - dtype = value.dtype - value = value.data.cpu().numpy().astype(np.float32) - path = np.zeros_like(value).astype(np.int32) - mask = mask.data.cpu().numpy() - - t_x_max = mask.sum(1)[:, 0].astype(np.int32) - t_y_max = mask.sum(2)[:, 0].astype(np.int32) - maximum_path_c(path, value, t_x_max, t_y_max) - return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx deleted file mode 100644 index 091fcc3a5..000000000 --- a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -cimport cython -cimport numpy as np - -from cython.parallel import prange - - -@cython.boundscheck(False) -@cython.wraparound(False) -cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: - cdef int x - cdef int y - cdef float v_prev - cdef float v_cur - cdef float tmp - cdef int index = t_x - 1 - - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[x, y-1] - if x == 0: - if y == 0: - v_prev = 0. 
- else: - v_prev = max_neg_val - else: - v_prev = value[x-1, y-1] - value[x, y] = max(v_cur, v_prev) + value[x, y] - - for y in range(t_y - 1, -1, -1): - path[index, y] = 1 - if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): - index = index - 1 - - -@cython.boundscheck(False) -@cython.wraparound(False) -cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: - cdef int b = values.shape[0] - - cdef int i - for i in prange(b, nogil=True): - maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py deleted file mode 100644 index beacf2e36..000000000 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ /dev/null @@ -1,30 +0,0 @@ -# Modified from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py -from Cython.Build import cythonize -from setuptools import Extension, setup -from setuptools.command.build_ext import build_ext as _build_ext - - -class build_ext(_build_ext): - """Overwrite build_ext.""" - - def finalize_options(self): - """Prevent numpy from thinking it is still in its setup process.""" - _build_ext.finalize_options(self) - __builtins__.__NUMPY_SETUP__ = False - import numpy - - self.include_dirs.append(numpy.get_include()) - - -exts = [ - Extension( - name="core", - sources=["core.pyx"], - ) -] -setup( - name="monotonic_align", - ext_modules=cythonize(exts, language_level=3), - cmdclass={"build_ext": build_ext}, -) diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py deleted file mode 100755 index 19e9b49cb..000000000 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
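The Cython kernel above is a standard monotonic-alignment dynamic program: a forward pass accumulates the best path score under the monotonicity constraint, and a backward pass recovers the 0/1 path. A pure-NumPy reference of the same per-item logic (the compiled version additionally parallelises over the batch with prange):

import numpy as np

def maximum_path_ref(value, max_neg_val=-1e9):
    """value: (t_x, t_y) log-likelihoods; returns a 0/1 path of the same shape."""
    t_x, t_y = value.shape
    value = value.copy()
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            v_cur = max_neg_val if x == y else value[x, y - 1]
            if x == 0:
                v_prev = 0.0 if y == 0 else max_neg_val
            else:
                v_prev = value[x - 1, y - 1]
            value[x, y] += max(v_cur, v_prev)
    path = np.zeros_like(value, dtype=np.int32)
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1
    return path

print(maximum_path_ref(np.random.randn(3, 6)))

If the extension has not been built yet, the conventional setuptools invocation for the setup.py above would be python3 setup.py build_ext --inplace run inside monotonic_align/ (an assumption; the recipe's actual build step is not shown in this diff), which produces the build/, core.c and *.so entries listed in the .gitignore.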
(authors: Fangjun Kuang) - -import argparse -import datetime as dt -import logging - -import onnxruntime as ort -import soundfile as sf -import torch -from infer import load_vocoder -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--acoustic-model", - type=str, - required=True, - help="Path to the acoustic model", - ) - - parser.add_argument( - "--tokens", - type=str, - required=True, - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--vocoder", - type=str, - required=True, - help="Path to the vocoder", - ) - - parser.add_argument( - "--input-text", - type=str, - required=True, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=True, - help="The filename of the wave to save the generated speech", - ) - - return parser - - -class OnnxHifiGANModel: - def __init__( - self, - filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - self.model = ort.InferenceSession( - filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - for i in self.model.get_inputs(): - print(i) - - print("-----") - - for i in self.model.get_outputs(): - print(i) - - def __call__(self, x: torch.tensor): - assert x.ndim == 3, x.shape - assert x.shape[0] == 1, x.shape - - audio = self.model.run( - [self.model.get_outputs()[0].name], - { - self.model.get_inputs()[0].name: x.numpy(), - }, - )[0] - # audio: (batch_size, num_samples) - - return torch.from_numpy(audio) - - -class OnnxModel: - def __init__( - self, - filename: str, - tokens: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 2 - - self.session_opts = session_opts - self.tokenizer = Tokenizer(tokens) - self.model = ort.InferenceSession( - filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - metadata = self.model.get_modelmeta().custom_metadata_map - self.sample_rate = int(metadata["sample_rate"]) - - for i in self.model.get_inputs(): - print(i) - - print("-----") - - for i in self.model.get_outputs(): - print(i) - - def __call__(self, x: torch.tensor): - assert x.ndim == 2, x.shape - assert x.shape[0] == 1, x.shape - - x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - print("x_lengths", x_lengths) - print("x", x.shape) - - noise_scale = torch.tensor([1.0], dtype=torch.float32) - length_scale = torch.tensor([1.0], dtype=torch.float32) - - mel = self.model.run( - [self.model.get_outputs()[0].name], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lengths.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: length_scale.numpy(), - }, - )[0] - # mel: (batch_size, feat_dim, num_frames) - - return torch.from_numpy(mel) - - -@torch.no_grad() -def main(): - params = get_parser().parse_args() - logging.info(vars(params)) - - model = OnnxModel(params.acoustic_model, params.tokens) - vocoder = OnnxHifiGANModel(params.vocoder) - text = params.input_text - x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) - x = torch.tensor(x, dtype=torch.int64) - - start_t = dt.datetime.now() - mel = model(x) - end_t = dt.datetime.now() - - 
start_t2 = dt.datetime.now() - audio = vocoder(mel) - end_t2 = dt.datetime.now() - - print("audio", audio.shape) # (1, 1, num_samples) - audio = audio.squeeze() - - sample_rate = model.sample_rate - - t = (end_t - start_t).total_seconds() - t2 = (end_t2 - start_t2).total_seconds() - rtf_am = t * sample_rate / audio.shape[-1] - rtf_vocoder = t2 * sample_rate / audio.shape[-1] - print("RTF for acoustic model ", rtf_am) - print("RTF for vocoder", rtf_vocoder) - - # skip denoiser - sf.write(params.output_wav, audio, sample_rate, "PCM_16") - logging.info(f"Saved to {params.output_wav}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() - -""" - -|HifiGAN |RTF |#Parameters (M)| -|----------|-----|---------------| -|v1 |0.818| 13.926 | -|v2 |0.101| 0.925 | -|v3 |0.118| 1.462 | - -|Num steps|Acoustic Model RTF| -|---------|------------------| -| 2 | 0.039 | -| 3 | 0.047 | -| 4 | 0.071 | -| 5 | 0.076 | -| 6 | 0.103 | - -""" diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt deleted file mode 100644 index d7829c1e1..000000000 --- a/egs/ljspeech/TTS/matcha/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -conformer==0.3.2 -diffusers # developed using version ==0.25.0 -librosa -einops \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/tokenizer.py b/egs/ljspeech/TTS/matcha/tokenizer.py deleted file mode 120000 index 44a19b0f4..000000000 --- a/egs/ljspeech/TTS/matcha/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../vits/tokenizer.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py deleted file mode 100755 index 853042413..000000000 --- a/egs/ljspeech/TTS/matcha/train.py +++ /dev/null @@ -1,719 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
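Both wrappers above build single-threaded CPU onnxruntime sessions and feed inputs by name. A minimal sketch of that setup; the filename is a placeholder rather than an artifact shipped with this recipe, so the actual construction is left commented out:

import onnxruntime as ort

def make_session(filename: str) -> ort.InferenceSession:
    opts = ort.SessionOptions()
    opts.inter_op_num_threads = 1
    opts.intra_op_num_threads = 1
    return ort.InferenceSession(
        filename, sess_options=opts, providers=["CPUExecutionProvider"]
    )

# session = make_session("matcha.onnx")   # placeholder path
# out = session.run([session.get_outputs()[0].name],
#                   {session.get_inputs()[0].name: x.numpy()})[0]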
(authors: Fangjun Kuang) - - -import argparse -import json -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Union - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.utils import fix_random_seed -from model import fix_len_compatibility -from models.matcha_tts import MatchaTTS -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker - -from icefall.checkpoint import load_checkpoint, save_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12335, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=10, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_data_statistics(): - return AttributeDict( - { - "mel_mean": 0, - "mel_std": 1, - } - ) - - -def _get_data_params() -> AttributeDict: - params = AttributeDict( - { - "name": "ljspeech", - "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", - "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - # "batch_size": 64, - # "num_workers": 1, - # "pin_memory": False, - "cleaners": ["english_cleaners2"], - "add_blank": True, - "n_spks": 1, - "n_fft": 1024, - "n_feats": 80, - "sampling_rate": 22050, - "hop_length": 256, - "win_length": 1024, - "f_min": 0, - "f_max": 8000, - "seed": 1234, - "load_durations": False, - "data_statistics": get_data_statistics(), - } - ) - return params - - -def _get_model_params() -> AttributeDict: - n_feats = 80 - filter_channels_dp = 256 - encoder_params_p_dropout = 0.1 - params = AttributeDict( - { - "n_spks": 1, # for ljspeech. - "spk_emb_dim": 64, - "n_feats": n_feats, - "out_size": None, # or use 172 - "prior_loss": True, - "use_precomputed_durations": False, - "data_statistics": get_data_statistics(), - "encoder": AttributeDict( - { - "encoder_type": "RoPE Encoder", # not used - "encoder_params": AttributeDict( - { - "n_feats": n_feats, - "n_channels": 192, - "filter_channels": 768, - "filter_channels_dp": filter_channels_dp, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - "spk_emb_dim": 64, - "n_spks": 1, - "prenet": True, - } - ), - "duration_predictor_params": AttributeDict( - { - "filter_channels_dp": filter_channels_dp, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - } - ), - } - ), - "decoder": AttributeDict( - { - "channels": [256, 256], - "dropout": 0.05, - "attention_head_dim": 64, - "n_blocks": 1, - "num_mid_blocks": 2, - "num_heads": 2, - "act_fn": "snakebeta", - } - ), - "cfm": AttributeDict( - { - "name": "CFM", - "solver": "euler", - "sigma_min": 1e-4, - } - ), - "optimizer": AttributeDict( - { - "lr": 1e-4, - "weight_decay": 0.0, - } - ), - } - ) - - return params - - -def get_params(): - params = AttributeDict( - { - "model_args": _get_model_params(), - "data_args": _get_data_params(), - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 10, - "valid_interval": 1500, - "env_info": get_env_info(), - } - ) - return params - - -def get_model(params): - m = MatchaTTS(**params.model_args) - return m - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): - """Parse batch data""" - mel_mean = params.data_args.data_statistics.mel_mean - mel_std_inv = 1 / params.data_args.data_statistics.mel_std - for i in range(batch["features"].shape[0]): - n = batch["features_lens"][i] - batch["features"][i : i + 1, :n, :] = ( - batch["features"][i : i + 1, :n, :] - mel_mean - ) * mel_std_inv - batch["features"][i : i + 1, n:, :] = 0 - - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - max_feature_length = fix_len_compatibility(features.shape[1]) - if max_feature_length > features.shape[1]: - pad = max_feature_length - features.shape[1] - features = torch.nn.functional.pad(features, (0, 0, 0, pad)) - - # features_lens[features_lens.argmax()] += pad - - return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long() - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) - - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) - - batch_size = len(batch["tokens"]) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - s = 0 - - for key, value in losses.items(): - v = value.detach().item() - loss_info[key] = v * batch_size - s += v * batch_size - - loss_info["tot_loss"] = s - - # summary stats - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["tot_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer: Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the 
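prepare_input above normalises only the valid frames of each padded utterance with the global mel mean/std loaded from cmvn.json and zeroes the padding (the subsequent fix_len_compatibility padding lives in the deleted model.py and is not shown here). A toy standalone sketch with made-up statistics:

import torch

features = torch.randn(2, 7, 80)             # (B, T, n_mels), padded batch
features_lens = torch.tensor([7, 5])
mel_mean, mel_std = -5.5, 2.1                 # hypothetical values from cmvn.json

mel_std_inv = 1 / mel_std
for i in range(features.shape[0]):
    n = int(features_lens[i])
    features[i, :n] = (features[i, :n] - mel_mean) * mel_std_inv
    features[i, n:] = 0

print(features[1, 5:].abs().sum())            # tensor(0.) -> padding region cleared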
model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer=optimizer, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - # audio: (N, T), float32 - # features: (N, T, C), float32 - # audio_lens, (N,), int32 - # features_lens, (N,), int32 - # tokens: List[List[str]], len(tokens) == N - - batch_size = len(batch["tokens"]) - - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) - try: - with autocast(enabled=params.use_fp16): - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) - - loss = sum(losses.values()) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - s = 0 - - for key, value in losses.items(): - v = value.detach().item() - loss_info[key] = v * batch_size - s += v * batch_size - - loss_info["tot_loss"] = s - - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. - # The _growth_interval of the grad scaler is configurable, - # but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, " - f"batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if params.batch_idx_train % params.valid_interval == 1: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - rank=rank, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - "Maximum memory allocated so far is " - f"{torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["tot_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.pad_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - - logging.info(params) - print(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of parameters: {num_param}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = 
torch.optim.Adam(model.parameters(), **params.model_args.optimizer) - - logging.info("About to create datamodule") - - ljspeech = LJSpeechTtsDataModule(args) - - train_cuts = ljspeech.train_cuts() - train_dl = ljspeech.train_dataloaders(train_cuts) - - valid_cuts = ljspeech.valid_cuts() - valid_dl = ljspeech.valid_dataloaders(valid_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - fix_random_seed(params.seed + epoch - 1) - if "sampler" in train_dl: - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer=optimizer, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer=optimizer, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py deleted file mode 100644 index 1e637b766..000000000 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
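Before the datamodule code below, it may help to condense the mixed-precision machinery that the removed train.py wires together: forward under autocast, loss scaling before backward, and a scaler-driven optimizer step plus scale update. The sketch below is a simplified illustration with a toy model, not the recipe's actual training step; amp_train_step and the toy tensors are made up for this example, and real fp16 requires a CUDA device.

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def amp_train_step(model, optimizer, scaler, batch, use_fp16=True):
        # Condensed form of the step in the removed train_one_epoch():
        # autocast forward, scale the loss, then let the scaler drive
        # optimizer.step() and adjust the loss scale.
        with autocast(enabled=use_fp16):
            loss = model(batch).mean()  # stand-in for sum(get_losses(...).values())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        return loss.detach()

    # Toy usage; enabled=False keeps it runnable on CPU.
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = GradScaler(enabled=False)
    x = torch.randn(4, 8)
    amp_train_step(model, optimizer, scaler, x, use_fp16=False)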
- - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from fbank import MatchaFbank, MatchaFbankConfig -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LJSpeechTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - pin_memory=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=True, - pin_memory=True, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" - ) diff --git a/egs/ljspeech/TTS/matcha/utils.py b/egs/ljspeech/TTS/matcha/utils.py deleted file mode 120000 index c2144f8e0..000000000 --- a/egs/ljspeech/TTS/matcha/utils.py +++ /dev/null @@ -1 +0,0 @@ -../vits/utils.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh deleted file mode 100755 index ec5062933..000000000 --- a/egs/ljspeech/TTS/prepare.sh +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou 
pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)" - for recipe in vits matcha; do - if [ ! -d $recipe/monotonic_align/build ]; then - cd $recipe/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for $recipe already built" - fi - done -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # The directory $dl_dir/LJSpeech-1.1 will contain: - # - wavs, which contains the audio files - # - metadata.csv, which provides the transcript text for each audio clip - - # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink - # - # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1 - # - if [ ! -d $dl_dir/LJSpeech-1.1 ]; then - lhotse download ljspeech $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LJSpeech manifest" - # We assume that you have downloaded the LJSpeech corpus - # to $dl_dir/LJSpeech-1.1 - mkdir -p data/manifests - if [ ! -e data/manifests/.ljspeech.done ]; then - lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests - touch data/manifests/.ljspeech.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)" - mkdir -p data/spectrogram - if [ ! -e data/spectrogram/.ljspeech.done ]; then - ./local/compute_spectrogram_ljspeech.py - touch data/spectrogram/.ljspeech.done - fi - - if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then - log "Validating data/spectrogram for LJSpeech (used by ./vits)" - python3 ./local/validate_manifest.py \ - data/spectrogram/ljspeech_cuts_all.jsonl.gz - touch data/spectrogram/.ljspeech-validated.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, - # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then - ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram - mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ - data/spectrogram/ljspeech_cuts_all.jsonl.gz - touch data/spectrogram/.ljspeech_with_token.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)" - if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then - lhotse subset --last 600 \ - data/spectrogram/ljspeech_cuts_all.jsonl.gz \ - data/spectrogram/ljspeech_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ - data/spectrogram/ljspeech_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ - data/spectrogram/ljspeech_cuts_test.jsonl.gz - - rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) - lhotse subset --first $n \ - data/spectrogram/ljspeech_cuts_all.jsonl.gz \ - data/spectrogram/ljspeech_cuts_train.jsonl.gz - touch data/spectrogram/.ljspeech_split.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Generate token file" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/tokens.txt ]; then - ./local/prepare_token_file.py --tokens data/tokens.txt - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate fbank (used by ./matcha)" - mkdir -p data/fbank - if [ ! -e data/fbank/.ljspeech.done ]; then - ./local/compute_fbank_ljspeech.py - touch data/fbank/.ljspeech.done - fi - - if [ ! -e data/fbank/.ljspeech-validated.done ]; then - log "Validating data/fbank for LJSpeech (used by ./matcha)" - python3 ./local/validate_manifest.py \ - data/fbank/ljspeech_cuts_all.jsonl.gz - touch data/fbank/.ljspeech-validated.done - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, - # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/fbank/.ljspeech_with_token.done ]; then - ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank - mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \ - data/fbank/ljspeech_cuts_all.jsonl.gz - touch data/fbank/.ljspeech_with_token.done - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)" - if [ ! 
-e data/fbank/.ljspeech_split.done ]; then - lhotse subset --last 600 \ - data/fbank/ljspeech_cuts_all.jsonl.gz \ - data/fbank/ljspeech_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/fbank/ljspeech_cuts_validtest.jsonl.gz \ - data/fbank/ljspeech_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/fbank/ljspeech_cuts_validtest.jsonl.gz \ - data/fbank/ljspeech_cuts_test.jsonl.gz - - rm data/fbank/ljspeech_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) - lhotse subset --first $n \ - data/fbank/ljspeech_cuts_all.jsonl.gz \ - data/fbank/ljspeech_cuts_train.jsonl.gz - touch data/fbank/.ljspeech_split.done - fi -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compute fbank mean and std (used by ./matcha)" - if [ ! -f ./data/fbank/cmvn.json ]; then - ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json - fi -fi diff --git a/egs/ljspeech/TTS/shared b/egs/ljspeech/TTS/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/ljspeech/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md deleted file mode 100644 index f2deed588..000000000 --- a/egs/ljspeech/TTS/vits/README.md +++ /dev/null @@ -1,4 +0,0 @@ -See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. - -Training logs, Tensorboard logs, and checkpoints are uploaded to -https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28 diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py deleted file mode 100644 index 1a8190014..000000000 --- a/egs/ljspeech/TTS/vits/duration_predictor.py +++ /dev/null @@ -1,193 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Stochastic duration predictor modules in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -import math -from typing import Optional - -import torch -import torch.nn.functional as F -from flow import ( - ConvFlow, - DilatedDepthSeparableConv, - ElementwiseAffineFlow, - FlipFlow, - LogFlow, -) - - -class StochasticDurationPredictor(torch.nn.Module): - """Stochastic duration predictor module. - - This is a module of stochastic duration predictor described in `Conditional - Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. - - .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End - Text-to-Speech`: https://arxiv.org/abs/2006.04558 - - """ - - def __init__( - self, - channels: int = 192, - kernel_size: int = 3, - dropout_rate: float = 0.5, - flows: int = 4, - dds_conv_layers: int = 3, - global_channels: int = -1, - ): - """Initialize StochasticDurationPredictor module. - - Args: - channels (int): Number of channels. - kernel_size (int): Kernel size. - dropout_rate (float): Dropout rate. - flows (int): Number of flows. - dds_conv_layers (int): Number of conv layers in DDS conv. - global_channels (int): Number of global conditioning channels. 
- - """ - super().__init__() - - self.pre = torch.nn.Conv1d(channels, channels, 1) - self.dds = DilatedDepthSeparableConv( - channels, - kernel_size, - layers=dds_conv_layers, - dropout_rate=dropout_rate, - ) - self.proj = torch.nn.Conv1d(channels, channels, 1) - - self.log_flow = LogFlow() - self.flows = torch.nn.ModuleList() - self.flows += [ElementwiseAffineFlow(2)] - for i in range(flows): - self.flows += [ - ConvFlow( - 2, - channels, - kernel_size, - layers=dds_conv_layers, - ) - ] - self.flows += [FlipFlow()] - - self.post_pre = torch.nn.Conv1d(1, channels, 1) - self.post_dds = DilatedDepthSeparableConv( - channels, - kernel_size, - layers=dds_conv_layers, - dropout_rate=dropout_rate, - ) - self.post_proj = torch.nn.Conv1d(channels, channels, 1) - self.post_flows = torch.nn.ModuleList() - self.post_flows += [ElementwiseAffineFlow(2)] - for i in range(flows): - self.post_flows += [ - ConvFlow( - 2, - channels, - kernel_size, - layers=dds_conv_layers, - ) - ] - self.post_flows += [FlipFlow()] - - if global_channels > 0: - self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - w: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - inverse: bool = False, - noise_scale: float = 1.0, - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T_text). - x_mask (Tensor): Mask tensor (B, 1, T_text). - w (Optional[Tensor]): Duration tensor (B, 1, T_text). - g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) - inverse (bool): Whether to inverse the flow. - noise_scale (float): Noise scale value. - - Returns: - Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). - If inverse, log-duration tensor (B, 1, T_text). - - """ - x = x.detach() # stop gradient - x = self.pre(x) - if g is not None: - x = x + self.global_conv(g.detach()) # stop gradient - x = self.dds(x, x_mask) - x = self.proj(x) * x_mask - - if not inverse: - assert w is not None, "w must be provided." 
- h_w = self.post_pre(w) - h_w = self.post_dds(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = ( - torch.randn( - w.size(0), - 2, - w.size(2), - ).to(device=x.device, dtype=x.dtype) - * x_mask - ) - z_q = e_q - logdet_tot_q = 0.0 - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum( - (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] - ) - logq = ( - torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - - logdet_tot_q - ) - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in self.flows: - z, logdet = flow(z, x_mask, g=x, inverse=inverse) - logdet_tot = logdet_tot + logdet - nll = ( - torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - - logdet_tot - ) - return nll + logq # (B,) - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = ( - torch.randn( - x.size(0), - 2, - x.size(2), - ).to(device=x.device, dtype=x.dtype) - * noise_scale - ) - for flow in flows: - z = flow(z, x_mask, g=x, inverse=inverse) - z0, z1 = z.split(1, 1) - logw = z0 - return logw diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py deleted file mode 100755 index 0740757c0..000000000 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script exports a VITS model from PyTorch to ONNX. - -Export the model to ONNX: -./vits/export-onnx.py \ - --epoch 1000 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt - -It will generate one file inside vits/exp: - - vits-epoch-1000.onnx - -See ./test_onnx.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import torch -import torch.nn as nn -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. 
- It controls the model size. low -> runs faster. - """, - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class OnnxModel(nn.Module): - """A wrapper for VITS generator.""" - - def __init__(self, model: nn.Module): - """ - Args: - model: - A VITS generator. - frame_shift: - The frame shift in samples. - """ - super().__init__() - self.model = model - - def forward( - self, - tokens: torch.Tensor, - tokens_lens: torch.Tensor, - noise_scale: float = 0.667, - alpha: float = 1.0, - noise_scale_dur: float = 0.8, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of VITS.inference_batch - - Args: - tokens: - Input text token indexes (1, T_text) - tokens_lens: - Number of tokens of shape (1,) - noise_scale (float): - Noise scale parameter for flow. - noise_scale_dur (float): - Noise scale parameter for duration predictor. - alpha (float): - Alpha parameter to control the speed of generated speech. - - Returns: - Return a tuple containing: - - audio, generated wavform tensor, (B, T_wav) - """ - audio, _, _ = self.model.generator.inference( - text=tokens, - text_lengths=tokens_lens, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - ) - return audio - - -def export_model_onnx( - model: nn.Module, - model_filename: str, - vocab_size: int, - opset_version: int = 11, -) -> None: - """Export the given generator model to ONNX format. - The exported model has one input: - - - tokens, a tensor of shape (1, T_text); dtype is torch.int64 - - and it has one output: - - - audio, a tensor of shape (1, T'); dtype is torch.float32 - - Args: - model: - The VITS generator. - model_filename: - The filename to save the exported ONNX model. - vocab_size: - Number of tokens used in training. - opset_version: - The opset version to use. 
- """ - tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1], dtype=torch.float32) - noise_scale_dur = torch.tensor([1], dtype=torch.float32) - alpha = torch.tensor([1], dtype=torch.float32) - - torch.onnx.export( - model, - (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), - model_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "tokens", - "tokens_lens", - "noise_scale", - "alpha", - "noise_scale_dur", - ], - output_names=["audio"], - dynamic_axes={ - "tokens": {0: "N", 1: "T"}, - "tokens_lens": {0: "N"}, - "audio": {0: "N", 1: "T"}, - }, - ) - - if model.model.spks is None: - num_speakers = 1 - else: - num_speakers = model.model.spks - - meta_data = { - "model_type": "vits", - "version": "1", - "model_author": "k2-fsa", - "comment": "icefall", # must be icefall for models from icefall - "language": "English", - "voice": "en-us", # Choose your language appropriately - "has_espeak": 1, - "n_speakers": num_speakers, - "sample_rate": model.model.sampling_rate, # Must match the real sample rate - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=model_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to("cpu") - model.eval() - - model = OnnxModel(model=model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") - - suffix = f"epoch-{params.epoch}" - - opset_version = 13 - - logging.info("Exporting encoder") - model_filename = params.exp_dir / f"vits-{suffix}.onnx" - export_model_onnx( - model, - model_filename, - params.vocab_size, - opset_version=opset_version, - ) - logging.info(f"Exported generator to {model_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() - -""" -Supported languages. - -LJSpeech is using "en-us" from the second column. 
- -Pty Language Age/Gender VoiceName File Other Languages - 5 af --/M Afrikaans gmw/af - 5 am --/M Amharic sem/am - 5 an --/M Aragonese roa/an - 5 ar --/M Arabic sem/ar - 5 as --/M Assamese inc/as - 5 az --/M Azerbaijani trk/az - 5 ba --/M Bashkir trk/ba - 5 be --/M Belarusian zle/be - 5 bg --/M Bulgarian zls/bg - 5 bn --/M Bengali inc/bn - 5 bpy --/M Bishnupriya_Manipuri inc/bpy - 5 bs --/M Bosnian zls/bs - 5 ca --/M Catalan roa/ca - 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr - 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) - 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) - 5 cs --/M Czech zlw/cs - 5 cv --/M Chuvash trk/cv - 5 cy --/M Welsh cel/cy - 5 da --/M Danish gmq/da - 5 de --/M German gmw/de - 5 el --/M Greek grk/el - 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) - 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) - 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) - 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) - 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) - 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) - 2 en-us --/M English_(America) gmw/en-US (en 3) - 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc - 5 eo --/M Esperanto art/eo - 5 es --/M Spanish_(Spain) roa/es - 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) - 5 et --/M Estonian urj/et - 5 eu --/M Basque eu - 5 fa --/M Persian ira/fa - 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn - 5 fi --/M Finnish urj/fi - 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) - 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) - 5 fr-fr --/M French_(France) roa/fr (fr 5) - 5 ga --/M Gaelic_(Irish) cel/ga - 5 gd --/M Gaelic_(Scottish) cel/gd - 5 gn --/M Guarani sai/gn - 5 grc --/M Greek_(Ancient) grk/grc - 5 gu --/M Gujarati inc/gu - 5 hak --/M Hakka_Chinese sit/hak - 5 haw --/M Hawaiian map/haw - 5 he --/M Hebrew sem/he - 5 hi --/M Hindi inc/hi - 5 hr --/M Croatian zls/hr (hbs 5) - 5 ht --/M Haitian_Creole roa/ht - 5 hu --/M Hungarian urj/hu - 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) - 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) - 5 ia --/M Interlingua art/ia - 5 id --/M Indonesian poz/id - 5 io --/M Ido art/io - 5 is --/M Icelandic gmq/is - 5 it --/M Italian roa/it - 5 ja --/M Japanese jpx/ja - 5 jbo --/M Lojban art/jbo - 5 ka --/M Georgian ccs/ka - 5 kk --/M Kazakh trk/kk - 5 kl --/M Greenlandic esx/kl - 5 kn --/M Kannada dra/kn - 5 ko --/M Korean ko - 5 kok --/M Konkani inc/kok - 5 ku --/M Kurdish ira/ku - 5 ky --/M Kyrgyz trk/ky - 5 la --/M Latin itc/la - 5 lb --/M Luxembourgish gmw/lb - 5 lfn --/M Lingua_Franca_Nova art/lfn - 5 lt --/M Lithuanian bat/lt - 5 ltg --/M Latgalian bat/ltg - 5 lv --/M Latvian bat/lv - 5 mi --/M Māori poz/mi - 5 mk --/M Macedonian zls/mk - 5 ml --/M Malayalam dra/ml - 5 mr --/M Marathi inc/mr - 5 ms --/M Malay poz/ms - 5 mt --/M Maltese sem/mt - 5 mto --/M Totontepec_Mixe miz/mto - 5 my --/M Myanmar_(Burmese) sit/my - 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) - 5 nci --/M Nahuatl_(Classical) azc/nci - 5 ne --/M Nepali inc/ne - 5 nl --/M Dutch gmw/nl - 5 nog --/M Nogai trk/nog - 5 om --/M Oromo cus/om - 5 or --/M Oriya inc/or - 5 pa --/M Punjabi inc/pa - 5 pap --/M Papiamento roa/pap - 5 piqd --/M Klingon art/piqd - 5 pl --/M Polish zlw/pl - 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) - 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) - 5 py --/M 
Pyash art/py - 5 qdb --/M Lang_Belta art/qdb - 5 qu --/M Quechua qu - 5 quc --/M K'iche' myn/quc - 5 qya --/M Quenya art/qya - 5 ro --/M Romanian roa/ro - 5 ru --/M Russian zle/ru - 5 ru-cl --/M Russian_(Classic) zle/ru-cl - 2 ru-lv --/M Russian_(Latvia) zle/ru-LV - 5 sd --/M Sindhi inc/sd - 5 shn --/M Shan_(Tai_Yai) tai/shn - 5 si --/M Sinhala inc/si - 5 sjn --/M Sindarin art/sjn - 5 sk --/M Slovak zlw/sk - 5 sl --/M Slovenian zls/sl - 5 smj --/M Lule_Saami urj/smj - 5 sq --/M Albanian ine/sq - 5 sr --/M Serbian zls/sr - 5 sv --/M Swedish gmq/sv - 5 sw --/M Swahili bnt/sw - 5 ta --/M Tamil dra/ta - 5 te --/M Telugu dra/te - 5 th --/M Thai tai/th - 5 tk --/M Turkmen trk/tk - 5 tn --/M Setswana bnt/tn - 5 tr --/M Turkish trk/tr - 5 tt --/M Tatar trk/tt - 5 ug --/M Uyghur trk/ug - 5 uk --/M Ukrainian zle/uk - 5 ur --/M Urdu inc/ur - 5 uz --/M Uzbek trk/uz - 5 vi --/M Vietnamese_(Northern) aav/vi - 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central - 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south - 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) - 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) -""" diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py deleted file mode 100644 index 2b84f6434..000000000 --- a/egs/ljspeech/TTS/vits/flow.py +++ /dev/null @@ -1,311 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Basic Flow modules used in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -import math -from typing import Optional, Tuple, Union - -import torch -from transform import piecewise_rational_quadratic_transform - - -class FlipFlow(torch.nn.Module): - """Flip flow module.""" - - def forward( - self, x: torch.Tensor, *args, inverse: bool = False, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Flipped tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - x = torch.flip(x, [1]) - if not inverse: - logdet = x.new_zeros(x.size(0)) - return x, logdet - else: - return x - - -class LogFlow(torch.nn.Module): - """Log flow module.""" - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - inverse: bool = False, - eps: float = 1e-5, - **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_mask (Tensor): Mask tensor (B, 1, T). - inverse (bool): Whether to inverse the flow. - eps (float): Epsilon for log. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - if not inverse: - y = torch.log(torch.clamp_min(x, eps)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - - -class ElementwiseAffineFlow(torch.nn.Module): - """Elementwise affine flow module.""" - - def __init__(self, channels: int): - """Initialize ElementwiseAffineFlow module. - - Args: - channels (int): Number of channels. 
- - """ - super().__init__() - self.channels = channels - self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) - self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) - - def forward( - self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_lengths (Tensor): Length tensor (B,). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - if not inverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - -class Transpose(torch.nn.Module): - """Transpose module for torch.nn.Sequential().""" - - def __init__(self, dim1: int, dim2: int): - """Initialize Transpose module.""" - super().__init__() - self.dim1 = dim1 - self.dim2 = dim2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Transpose.""" - return x.transpose(self.dim1, self.dim2) - - -class DilatedDepthSeparableConv(torch.nn.Module): - """Dilated depth-separable conv module.""" - - def __init__( - self, - channels: int, - kernel_size: int, - layers: int, - dropout_rate: float = 0.0, - eps: float = 1e-5, - ): - """Initialize DilatedDepthSeparableConv module. - - Args: - channels (int): Number of channels. - kernel_size (int): Kernel size. - layers (int): Number of layers. - dropout_rate (float): Dropout rate. - eps (float): Epsilon for layer norm. - - """ - super().__init__() - - self.convs = torch.nn.ModuleList() - for i in range(layers): - dilation = kernel_size**i - padding = (kernel_size * dilation - dilation) // 2 - self.convs += [ - torch.nn.Sequential( - torch.nn.Conv1d( - channels, - channels, - kernel_size, - groups=channels, - dilation=dilation, - padding=padding, - ), - Transpose(1, 2), - torch.nn.LayerNorm( - channels, - eps=eps, - elementwise_affine=True, - ), - Transpose(1, 2), - torch.nn.GELU(), - torch.nn.Conv1d( - channels, - channels, - 1, - ), - Transpose(1, 2), - torch.nn.LayerNorm( - channels, - eps=eps, - elementwise_affine=True, - ), - Transpose(1, 2), - torch.nn.GELU(), - torch.nn.Dropout(dropout_rate), - ) - ] - - def forward( - self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, in_channels, T). - x_mask (Tensor): Mask tensor (B, 1, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, channels, T). - - """ - if g is not None: - x = x + g - for f in self.convs: - y = f(x * x_mask) - x = x + y - return x * x_mask - - -class ConvFlow(torch.nn.Module): - """Convolutional flow module.""" - - def __init__( - self, - in_channels: int, - hidden_channels: int, - kernel_size: int, - layers: int, - bins: int = 10, - tail_bound: float = 5.0, - ): - """Initialize ConvFlow module. - - Args: - in_channels (int): Number of input channels. - hidden_channels (int): Number of hidden channels. - kernel_size (int): Kernel size. - layers (int): Number of layers. - bins (int): Number of bins. - tail_bound (float): Tail bound value. 
- - """ - super().__init__() - self.half_channels = in_channels // 2 - self.hidden_channels = hidden_channels - self.bins = bins - self.tail_bound = tail_bound - - self.input_conv = torch.nn.Conv1d( - self.half_channels, - hidden_channels, - 1, - ) - self.dds_conv = DilatedDepthSeparableConv( - hidden_channels, - kernel_size, - layers, - dropout_rate=0.0, - ) - self.proj = torch.nn.Conv1d( - hidden_channels, - self.half_channels * (bins * 3 - 1), - 1, - ) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - g: Optional[torch.Tensor] = None, - inverse: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_mask (Tensor): Mask tensor (B,). - g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - xa, xb = x.split(x.size(1) // 2, 1) - h = self.input_conv(xa) - h = self.dds_conv(h, x_mask, g=g) - h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) - - b, c, t = xa.shape - # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) - - # TODO(kan-bayashi): Understand this calculation - denom = math.sqrt(self.hidden_channels) - unnorm_widths = h[..., : self.bins] / denom - unnorm_heights = h[..., self.bins : 2 * self.bins] / denom - unnorm_derivatives = h[..., 2 * self.bins :] - xb, logdet_abs = piecewise_rational_quadratic_transform( - xb, - unnorm_widths, - unnorm_heights, - unnorm_derivatives, - inverse=inverse, - tails="linear", - tail_bound=self.tail_bound, - ) - x = torch.cat([xa, xb], 1) * x_mask - logdet = torch.sum(logdet_abs * x_mask, [1, 2]) - if not inverse: - return x, logdet - else: - return x diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py deleted file mode 100644 index 521b0121f..000000000 --- a/egs/ljspeech/TTS/vits/generator.py +++ /dev/null @@ -1,535 +0,0 @@ -# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Generator module in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - - -import math -from typing import List, Optional, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from duration_predictor import StochasticDurationPredictor -from hifigan import HiFiGANGenerator -from posterior_encoder import PosteriorEncoder -from residual_coupling import ResidualAffineCouplingBlock -from text_encoder import TextEncoder -from utils import get_random_segments - -from icefall.utils import make_pad_mask - - -class VITSGenerator(torch.nn.Module): - """Generator module in VITS, `Conditional Variational Autoencoder - with Adversarial Learning for End-to-End Text-to-Speech`. 
- """ - - def __init__( - self, - vocabs: int, - aux_channels: int = 513, - hidden_channels: int = 192, - spks: Optional[int] = None, - langs: Optional[int] = None, - spk_embed_dim: Optional[int] = None, - global_channels: int = -1, - segment_size: int = 32, - text_encoder_attention_heads: int = 2, - text_encoder_ffn_expand: int = 4, - text_encoder_cnn_module_kernel: int = 5, - text_encoder_blocks: int = 6, - text_encoder_dropout_rate: float = 0.1, - decoder_kernel_size: int = 7, - decoder_channels: int = 512, - decoder_upsample_scales: List[int] = [8, 8, 2, 2], - decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], - decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], - decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - use_weight_norm_in_decoder: bool = True, - posterior_encoder_kernel_size: int = 5, - posterior_encoder_layers: int = 16, - posterior_encoder_stacks: int = 1, - posterior_encoder_base_dilation: int = 1, - posterior_encoder_dropout_rate: float = 0.0, - use_weight_norm_in_posterior_encoder: bool = True, - flow_flows: int = 4, - flow_kernel_size: int = 5, - flow_base_dilation: int = 1, - flow_layers: int = 4, - flow_dropout_rate: float = 0.0, - use_weight_norm_in_flow: bool = True, - use_only_mean_in_flow: bool = True, - stochastic_duration_predictor_kernel_size: int = 3, - stochastic_duration_predictor_dropout_rate: float = 0.5, - stochastic_duration_predictor_flows: int = 4, - stochastic_duration_predictor_dds_conv_layers: int = 3, - ): - """Initialize VITS generator module. - - Args: - vocabs (int): Input vocabulary size. - aux_channels (int): Number of acoustic feature channels. - hidden_channels (int): Number of hidden channels. - spks (Optional[int]): Number of speakers. If set to > 1, assume that the - sids will be provided as the input and use sid embedding layer. - langs (Optional[int]): Number of languages. If set to > 1, assume that the - lids will be provided as the input and use sid embedding layer. - spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, - assume that spembs will be provided as the input. - global_channels (int): Number of global conditioning channels. - segment_size (int): Segment size for decoder. - text_encoder_attention_heads (int): Number of heads in conformer block - of text encoder. - text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block - of text encoder. - text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. - text_encoder_blocks (int): Number of conformer blocks in text encoder. - text_encoder_dropout_rate (float): Dropout rate in conformer block of - text encoder. - decoder_kernel_size (int): Decoder kernel size. - decoder_channels (int): Number of decoder initial channels. - decoder_upsample_scales (List[int]): List of upsampling scales in decoder. - decoder_upsample_kernel_sizes (List[int]): List of kernel size for - upsampling layers in decoder. - decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks - in decoder. - decoder_resblock_dilations (List[List[int]]): List of list of dilations for - resblocks in decoder. - use_weight_norm_in_decoder (bool): Whether to apply weight normalization in - decoder. - posterior_encoder_kernel_size (int): Posterior encoder kernel size. - posterior_encoder_layers (int): Number of layers of posterior encoder. - posterior_encoder_stacks (int): Number of stacks of posterior encoder. - posterior_encoder_base_dilation (int): Base dilation of posterior encoder. 
- posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. - use_weight_norm_in_posterior_encoder (bool): Whether to apply weight - normalization in posterior encoder. - flow_flows (int): Number of flows in flow. - flow_kernel_size (int): Kernel size in flow. - flow_base_dilation (int): Base dilation in flow. - flow_layers (int): Number of layers in flow. - flow_dropout_rate (float): Dropout rate in flow - use_weight_norm_in_flow (bool): Whether to apply weight normalization in - flow. - use_only_mean_in_flow (bool): Whether to use only mean in flow. - stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic - duration predictor. - stochastic_duration_predictor_dropout_rate (float): Dropout rate in - stochastic duration predictor. - stochastic_duration_predictor_flows (int): Number of flows in stochastic - duration predictor. - stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv - layers in stochastic duration predictor. - - """ - super().__init__() - self.segment_size = segment_size - self.text_encoder = TextEncoder( - vocabs=vocabs, - d_model=hidden_channels, - num_heads=text_encoder_attention_heads, - dim_feedforward=hidden_channels * text_encoder_ffn_expand, - cnn_module_kernel=text_encoder_cnn_module_kernel, - num_layers=text_encoder_blocks, - dropout=text_encoder_dropout_rate, - ) - self.decoder = HiFiGANGenerator( - in_channels=hidden_channels, - out_channels=1, - channels=decoder_channels, - global_channels=global_channels, - kernel_size=decoder_kernel_size, - upsample_scales=decoder_upsample_scales, - upsample_kernel_sizes=decoder_upsample_kernel_sizes, - resblock_kernel_sizes=decoder_resblock_kernel_sizes, - resblock_dilations=decoder_resblock_dilations, - use_weight_norm=use_weight_norm_in_decoder, - ) - self.posterior_encoder = PosteriorEncoder( - in_channels=aux_channels, - out_channels=hidden_channels, - hidden_channels=hidden_channels, - kernel_size=posterior_encoder_kernel_size, - layers=posterior_encoder_layers, - stacks=posterior_encoder_stacks, - base_dilation=posterior_encoder_base_dilation, - global_channels=global_channels, - dropout_rate=posterior_encoder_dropout_rate, - use_weight_norm=use_weight_norm_in_posterior_encoder, - ) - self.flow = ResidualAffineCouplingBlock( - in_channels=hidden_channels, - hidden_channels=hidden_channels, - flows=flow_flows, - kernel_size=flow_kernel_size, - base_dilation=flow_base_dilation, - layers=flow_layers, - global_channels=global_channels, - dropout_rate=flow_dropout_rate, - use_weight_norm=use_weight_norm_in_flow, - use_only_mean=use_only_mean_in_flow, - ) - # TODO(kan-bayashi): Add deterministic version as an option - self.duration_predictor = StochasticDurationPredictor( - channels=hidden_channels, - kernel_size=stochastic_duration_predictor_kernel_size, - dropout_rate=stochastic_duration_predictor_dropout_rate, - flows=stochastic_duration_predictor_flows, - dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, - global_channels=global_channels, - ) - - self.upsample_factor = int(np.prod(decoder_upsample_scales)) - self.spks = None - if spks is not None and spks > 1: - assert global_channels > 0, global_channels - self.spks = spks - self.global_emb = torch.nn.Embedding(spks, global_channels) - self.spk_embed_dim = None - if spk_embed_dim is not None and spk_embed_dim > 0: - assert global_channels > 0 - self.spk_embed_dim = spk_embed_dim - self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) - self.langs = None - if langs is not None and langs > 1: - 
assert global_channels > 0 - self.langs = langs - self.lang_emb = torch.nn.Embedding(langs, global_channels) - - # delayed import - from monotonic_align import maximum_path - - self.maximum_path = maximum_path - - def forward( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ], - ]: - """Calculate forward propagation. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, aux_channels, T_feats). - feats_lengths (Tensor): Feature length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - - Returns: - Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). - Tensor: Duration negative log-likelihood (NLL) tensor (B,). - Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). - Tensor: Segments start index tensor (B,). - Tensor: Text mask tensor (B, 1, T_text). - Tensor: Feature mask tensor (B, 1, T_feats). - tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - - Tensor: Posterior encoder hidden representation (B, H, T_feats). - - Tensor: Flow hidden representation (B, H, T_feats). - - Tensor: Expanded text encoder projected mean (B, H, T_feats). - - Tensor: Expanded text encoder projected scale (B, H, T_feats). - - Tensor: Posterior encoder projected mean (B, H, T_feats). - - Tensor: Posterior encoder projected scale (B, H, T_feats). 
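The speaker-ID, x-vector and language-ID branches set up above all feed one global conditioning vector g of shape (B, global_channels, 1), which forward() below simply sums before handing it to the posterior encoder, flow, duration predictor and decoder. A shape-only sketch with toy sizes (all numbers here are illustrative, not the recipe's configuration):

import torch

B, spks, langs, global_channels = 2, 10, 3, 256   # toy sizes
global_emb = torch.nn.Embedding(spks, global_channels)
lang_emb = torch.nn.Embedding(langs, global_channels)

sids = torch.tensor([0, 3])   # speaker indices
lids = torch.tensor([1, 2])   # language indices
g = global_emb(sids.view(-1)).unsqueeze(-1) + lang_emb(lids.view(-1)).unsqueeze(-1)
print(g.shape)  # torch.Size([2, 256, 1])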
- - """ - # forward text encoder - x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) - - # calculate global conditioning - g = None - if self.spks is not None: - # speaker one-hot vector embedding: (B, global_channels, 1) - g = self.global_emb(sids.view(-1)).unsqueeze(-1) - if self.spk_embed_dim is not None: - # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) - g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) - if g is None: - g = g_ - else: - g = g + g_ - if self.langs is not None: - # language one-hot vector embedding: (B, global_channels, 1) - g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) - if g is None: - g = g_ - else: - g = g + g_ - - # forward posterior encoder - z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) - - # forward flow - z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) - - # monotonic alignment search - with torch.no_grad(): - # negative cross-entropy - s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) - # (B, 1, T_text) - neg_x_ent_1 = torch.sum( - -0.5 * math.log(2 * math.pi) - logs_p, - [1], - keepdim=True, - ) - # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) - neg_x_ent_2 = torch.matmul( - -0.5 * (z_p**2).transpose(1, 2), - s_p_sq_r, - ) - # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) - neg_x_ent_3 = torch.matmul( - z_p.transpose(1, 2), - (m_p * s_p_sq_r), - ) - # (B, 1, T_text) - neg_x_ent_4 = torch.sum( - -0.5 * (m_p**2) * s_p_sq_r, - [1], - keepdim=True, - ) - # (B, T_feats, T_text) - neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 - # (B, 1, T_feats, T_text) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - # monotonic attention weight: (B, 1, T_feats, T_text) - attn = ( - self.maximum_path( - neg_x_ent, - attn_mask.squeeze(1), - ) - .unsqueeze(1) - .detach() - ) - - # forward duration predictor - w = attn.sum(2) # (B, 1, T_text) - dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) - dur_nll = dur_nll / torch.sum(x_mask) - - # expand the length to match with the feature sequence - # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) - # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) - - # get random segments - z_segments, z_start_idxs = get_random_segments( - z, - feats_lengths, - self.segment_size, - ) - - # forward decoder with random segments - wav = self.decoder(z_segments, g=g) - - return ( - wav, - dur_nll, - attn, - z_start_idxs, - x_mask, - y_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) - - def inference( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: Optional[torch.Tensor] = None, - feats_lengths: Optional[torch.Tensor] = None, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - dur: Optional[torch.Tensor] = None, - noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, - alpha: float = 1.0, - max_len: Optional[int] = None, - use_teacher_forcing: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run inference. - - Args: - text (Tensor): Input text index tensor (B, T_text,). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, aux_channels, T_feats,). - feats_lengths (Tensor): Feature length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
- spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, - skip the prediction of durations (i.e., teacher forcing). - noise_scale (float): Noise scale parameter for flow. - noise_scale_dur (float): Noise scale parameter for duration predictor. - alpha (float): Alpha parameter to control the speed of generated speech. - max_len (Optional[int]): Maximum length of acoustic feature sequence. - use_teacher_forcing (bool): Whether to use teacher forcing. - - Returns: - Tensor: Generated waveform tensor (B, T_wav). - Tensor: Monotonic attention weight tensor (B, T_feats, T_text). - Tensor: Duration tensor (B, T_text). - - """ - # encoder - x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) - x_mask = x_mask.to(x.dtype) - g = None - if self.spks is not None: - # (B, global_channels, 1) - g = self.global_emb(sids.view(-1)).unsqueeze(-1) - if self.spk_embed_dim is not None: - # (B, global_channels, 1) - if spembs.ndim == 2: - g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) - elif spembs.ndim == 1: - g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) - else: - raise ValueError("spembs should be 1D or 2D (batch mode) tensor.") - if g is None: - g = g_ - else: - g = g + g_ - if self.langs is not None: - # (B, global_channels, 1) - g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) - if g is None: - g = g_ - else: - g = g + g_ - - if use_teacher_forcing: - # forward posterior encoder - z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) - - # forward flow - z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) - - # monotonic alignment search - s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) - # (B, 1, T_text) - neg_x_ent_1 = torch.sum( - -0.5 * math.log(2 * math.pi) - logs_p, - [1], - keepdim=True, - ) - # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) - neg_x_ent_2 = torch.matmul( - -0.5 * (z_p**2).transpose(1, 2), - s_p_sq_r, - ) - # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) - neg_x_ent_3 = torch.matmul( - z_p.transpose(1, 2), - (m_p * s_p_sq_r), - ) - # (B, 1, T_text) - neg_x_ent_4 = torch.sum( - -0.5 * (m_p**2) * s_p_sq_r, - [1], - keepdim=True, - ) - # (B, T_feats, T_text) - neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 - # (B, 1, T_feats, T_text) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - # monotonic attention weight: (B, 1, T_feats, T_text) - attn = self.maximum_path( - neg_x_ent, - attn_mask.squeeze(1), - ).unsqueeze(1) - dur = attn.sum(2) # (B, 1, T_text) - - # forward decoder with random segments - wav = self.decoder(z * y_mask, g=g) - else: - # duration - if dur is None: - logw = self.duration_predictor( - x, - x_mask, - g=g, - inverse=True, - noise_scale=noise_scale_dur, - ) - w = torch.exp(logw) * x_mask * alpha - dur = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() - y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) - y_mask = y_mask.to(x.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = self._generate_path(dur, attn_mask) - - # expand the length to match with the feature sequence - # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) - m_p = torch.matmul( - attn.squeeze(1), - m_p.transpose(1, 2), - ).transpose(1, 2) - # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) - logs_p = torch.matmul( - 
attn.squeeze(1), - logs_p.transpose(1, 2), - ).transpose(1, 2) - - # decoder - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - z = self.flow(z_p, y_mask, g=g, inverse=True) - wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) - - return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) - - def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """Generate path a.k.a. monotonic attention. - - Args: - dur (Tensor): Duration tensor (B, 1, T_text). - mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). - - Returns: - Tensor: Path tensor (B, 1, T_feats, T_text). - - """ - b, _, t_y, t_x = mask.shape - cum_dur = torch.cumsum(dur, -1) - cum_dur_flat = cum_dur.view(b * t_x) - path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) - path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) - # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) - path = path.view(b, t_x, t_y).to(dtype=torch.float) - # path will be like (t_x = 3, t_y = 5): - # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], - # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], - # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] - path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] - # path = path.to(dtype=mask.dtype) - return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py deleted file mode 100644 index 589ac30f6..000000000 --- a/egs/ljspeech/TTS/vits/hifigan.py +++ /dev/null @@ -1,933 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""HiFi-GAN Modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. - -""" - -import copy -import logging -from typing import Any, Dict, List, Optional - -import numpy as np -import torch -import torch.nn.functional as F - - -class HiFiGANGenerator(torch.nn.Module): - """HiFiGAN generator module.""" - - def __init__( - self, - in_channels: int = 80, - out_channels: int = 1, - channels: int = 512, - global_channels: int = -1, - kernel_size: int = 7, - upsample_scales: List[int] = [8, 8, 2, 2], - upsample_kernel_sizes: List[int] = [16, 16, 4, 4], - resblock_kernel_sizes: List[int] = [3, 7, 11], - resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - use_additional_convs: bool = True, - bias: bool = True, - nonlinear_activation: str = "LeakyReLU", - nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, - use_weight_norm: bool = True, - ): - """Initialize HiFiGANGenerator module. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - channels (int): Number of hidden representation channels. - global_channels (int): Number of global conditioning channels. - kernel_size (int): Kernel size of initial and final conv layer. - upsample_scales (List[int]): List of upsampling scales. - upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. - resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. - resblock_dilations (List[List[int]]): List of list of dilations for residual - blocks. - use_additional_convs (bool): Whether to use additional conv layers in - residual blocks. - bias (bool): Whether to add bias parameter in convolution layers. - nonlinear_activation (str): Activation function module name. 
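The cumulative-sum trick in _generate_path above can be reproduced in a few lines; expand_durations below is a hypothetical standalone helper (illustration only, not part of the recipe) that turns integer durations into the hard monotonic path sketched in the ASCII comment:

import torch
import torch.nn.functional as F

def expand_durations(dur: torch.Tensor, t_feats: int) -> torch.Tensor:
    """dur: (B, 1, T_text) durations in frames -> (B, T_feats, T_text) 0/1 path."""
    b, _, t_text = dur.shape
    cum_dur = torch.cumsum(dur, -1).view(b * t_text)
    frame_idx = torch.arange(t_feats, dtype=dur.dtype, device=dur.device)
    path = (frame_idx.unsqueeze(0) < cum_dur.unsqueeze(1)).to(dur.dtype)
    path = path.view(b, t_text, t_feats)
    path = path - F.pad(path, [0, 0, 1, 0])[:, :-1]  # keep only the newly covered frames
    return path.transpose(1, 2)

dur = torch.tensor([[[2.0, 2.0, 1.0]]])  # 3 tokens, 5 frames in total
print(expand_durations(dur, t_feats=5))
# frames 0-1 -> token 0, frames 2-3 -> token 1, frame 4 -> token 2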
- nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation - function. - use_weight_norm (bool): Whether to use weight norm. If set to true, it will - be applied to all of the conv layers. - - """ - super().__init__() - - # check hyperparameters are valid - assert kernel_size % 2 == 1, "Kernel size must be odd number." - assert len(upsample_scales) == len(upsample_kernel_sizes) - assert len(resblock_dilations) == len(resblock_kernel_sizes) - - # define modules - self.upsample_factor = int(np.prod(upsample_scales) * out_channels) - self.num_upsamples = len(upsample_kernel_sizes) - self.num_blocks = len(resblock_kernel_sizes) - self.input_conv = torch.nn.Conv1d( - in_channels, - channels, - kernel_size, - 1, - padding=(kernel_size - 1) // 2, - ) - self.upsamples = torch.nn.ModuleList() - self.blocks = torch.nn.ModuleList() - for i in range(len(upsample_kernel_sizes)): - assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] - self.upsamples += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.ConvTranspose1d( - channels // (2**i), - channels // (2 ** (i + 1)), - upsample_kernel_sizes[i], - upsample_scales[i], - padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, - output_padding=upsample_scales[i] % 2, - ), - ) - ] - for j in range(len(resblock_kernel_sizes)): - self.blocks += [ - ResidualBlock( - kernel_size=resblock_kernel_sizes[j], - channels=channels // (2 ** (i + 1)), - dilations=resblock_dilations[j], - bias=bias, - use_additional_convs=use_additional_convs, - nonlinear_activation=nonlinear_activation, - nonlinear_activation_params=nonlinear_activation_params, - ) - ] - self.output_conv = torch.nn.Sequential( - # NOTE(kan-bayashi): follow official implementation but why - # using different slope parameter here? (0.1 vs. 0.01) - torch.nn.LeakyReLU(), - torch.nn.Conv1d( - channels // (2 ** (i + 1)), - out_channels, - kernel_size, - 1, - padding=(kernel_size - 1) // 2, - ), - torch.nn.Tanh(), - ) - if global_channels > 0: - self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) - - # apply weight norm - if use_weight_norm: - self.apply_weight_norm() - - # reset parameters - self.reset_parameters() - - def forward( - self, c: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - c (Tensor): Input tensor (B, in_channels, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, out_channels, T). - - """ - c = self.input_conv(c) - if g is not None: - c = c + self.global_conv(g) - for i in range(self.num_upsamples): - c = self.upsamples[i](c) - cs = 0.0 # initialize - for j in range(self.num_blocks): - cs += self.blocks[i * self.num_blocks + j](c) - c = cs / self.num_blocks - c = self.output_conv(c) - - return c - - def reset_parameters(self): - """Reset parameters. - - This initialization follows the official implementation manner. 
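Each ConvTranspose1d stage above is configured (kernel 2*s, stride s, padding s//2 + s%2, output_padding s%2) so that it multiplies the time axis exactly by its scale s, which is why upsample_factor is simply the product of upsample_scales. A quick standalone check (arbitrary channel count, random input):

import torch

for s in [8, 8, 2, 2]:  # the default upsample_scales above
    up = torch.nn.ConvTranspose1d(4, 4, 2 * s, s, padding=s // 2 + s % 2, output_padding=s % 2)
    x = torch.randn(1, 4, 50)
    assert up(x).shape[-1] == 50 * s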
- https://github.com/jik876/hifi-gan/blob/master/models.py - - """ - - def _reset_parameters(m: torch.nn.Module): - if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): - m.weight.data.normal_(0.0, 0.01) - logging.debug(f"Reset parameters in {m}.") - - self.apply(_reset_parameters) - - def remove_weight_norm(self): - """Remove weight normalization module from all of the layers.""" - - def _remove_weight_norm(m: torch.nn.Module): - try: - logging.debug(f"Weight norm is removed from {m}.") - torch.nn.utils.remove_weight_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_weight_norm) - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d) or isinstance( - m, torch.nn.ConvTranspose1d - ): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - def inference( - self, c: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Perform inference. - - Args: - c (torch.Tensor): Input tensor (T, in_channels). - g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). - - Returns: - Tensor: Output tensor (T ** upsample_factor, out_channels). - - """ - if g is not None: - g = g.unsqueeze(0) - c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g) - return c.squeeze(0).transpose(1, 0) - - -class ResidualBlock(torch.nn.Module): - """Residual block module in HiFiGAN.""" - - def __init__( - self, - kernel_size: int = 3, - channels: int = 512, - dilations: List[int] = [1, 3, 5], - bias: bool = True, - use_additional_convs: bool = True, - nonlinear_activation: str = "LeakyReLU", - nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, - ): - """Initialize ResidualBlock module. - - Args: - kernel_size (int): Kernel size of dilation convolution layer. - channels (int): Number of channels for convolution layer. - dilations (List[int]): List of dilation factors. - use_additional_convs (bool): Whether to use additional convolution layers. - bias (bool): Whether to add bias parameter in convolution layers. - nonlinear_activation (str): Activation function module name. - nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation - function. - - """ - super().__init__() - self.use_additional_convs = use_additional_convs - self.convs1 = torch.nn.ModuleList() - if use_additional_convs: - self.convs2 = torch.nn.ModuleList() - assert kernel_size % 2 == 1, "Kernel size must be odd number." - for dilation in dilations: - self.convs1 += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation, - bias=bias, - padding=(kernel_size - 1) // 2 * dilation, - ), - ) - ] - if use_additional_convs: - self.convs2 += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - bias=bias, - padding=(kernel_size - 1) // 2, - ), - ) - ] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - - Returns: - Tensor: Output tensor (B, channels, T). 
- - """ - for idx in range(len(self.convs1)): - xt = self.convs1[idx](x) - if self.use_additional_convs: - xt = self.convs2[idx](xt) - x = xt + x - return x - - -class HiFiGANPeriodDiscriminator(torch.nn.Module): - """HiFiGAN period discriminator module.""" - - def __init__( - self, - in_channels: int = 1, - out_channels: int = 1, - period: int = 3, - kernel_sizes: List[int] = [5, 3], - channels: int = 32, - downsample_scales: List[int] = [3, 3, 3, 3, 1], - max_downsample_channels: int = 1024, - bias: bool = True, - nonlinear_activation: str = "LeakyReLU", - nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, - use_weight_norm: bool = True, - use_spectral_norm: bool = False, - ): - """Initialize HiFiGANPeriodDiscriminator module. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - period (int): Period. - kernel_sizes (list): Kernel sizes of initial conv layers and the final conv - layer. - channels (int): Number of initial channels. - downsample_scales (List[int]): List of downsampling scales. - max_downsample_channels (int): Number of maximum downsampling channels. - use_additional_convs (bool): Whether to use additional conv layers in - residual blocks. - bias (bool): Whether to add bias parameter in convolution layers. - nonlinear_activation (str): Activation function module name. - nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation - function. - use_weight_norm (bool): Whether to use weight norm. - If set to true, it will be applied to all of the conv layers. - use_spectral_norm (bool): Whether to use spectral norm. - If set to true, it will be applied to all of the conv layers. - - """ - super().__init__() - assert len(kernel_sizes) == 2 - assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." - assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." - - self.period = period - self.convs = torch.nn.ModuleList() - in_chs = in_channels - out_chs = channels - for downsample_scale in downsample_scales: - self.convs += [ - torch.nn.Sequential( - torch.nn.Conv2d( - in_chs, - out_chs, - (kernel_sizes[0], 1), - (downsample_scale, 1), - padding=((kernel_sizes[0] - 1) // 2, 0), - ), - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - ) - ] - in_chs = out_chs - # NOTE(kan-bayashi): Use downsample_scale + 1? - out_chs = min(out_chs * 4, max_downsample_channels) - self.output_conv = torch.nn.Conv2d( - out_chs, - out_channels, - (kernel_sizes[1] - 1, 1), - 1, - padding=((kernel_sizes[1] - 1) // 2, 0), - ) - - if use_weight_norm and use_spectral_norm: - raise ValueError("Either use use_weight_norm or use_spectral_norm.") - - # apply weight norm - if use_weight_norm: - self.apply_weight_norm() - - # apply spectral norm - if use_spectral_norm: - self.apply_spectral_norm() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Calculate forward propagation. - - Args: - c (Tensor): Input tensor (B, in_channels, T). - - Returns: - list: List of each layer's tensors. 
- - """ - # transform 1d to 2d -> (B, C, T/P, P) - b, c, t = x.shape - if t % self.period != 0: - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t += n_pad - x = x.view(b, c, t // self.period, self.period) - - # forward conv - outs = [] - for layer in self.convs: - x = layer(x) - outs += [x] - x = self.output_conv(x) - x = torch.flatten(x, 1, -1) - outs += [x] - - return outs - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv2d): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - def apply_spectral_norm(self): - """Apply spectral normalization module from all of the layers.""" - - def _apply_spectral_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv2d): - torch.nn.utils.spectral_norm(m) - logging.debug(f"Spectral norm is applied to {m}.") - - self.apply(_apply_spectral_norm) - - -class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): - """HiFiGAN multi-period discriminator module.""" - - def __init__( - self, - periods: List[int] = [2, 3, 5, 7, 11], - discriminator_params: Dict[str, Any] = { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [5, 3], - "channels": 32, - "downsample_scales": [3, 3, 3, 3, 1], - "max_downsample_channels": 1024, - "bias": True, - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - ): - """Initialize HiFiGANMultiPeriodDiscriminator module. - - Args: - periods (List[int]): List of periods. - discriminator_params (Dict[str, Any]): Parameters for hifi-gan period - discriminator module. The period parameter will be overwritten. - - """ - super().__init__() - self.discriminators = torch.nn.ModuleList() - for period in periods: - params = copy.deepcopy(discriminator_params) - params["period"] = period - self.discriminators += [HiFiGANPeriodDiscriminator(**params)] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T). - - Returns: - List: List of list of each discriminator outputs, which consists of each - layer output tensors. - - """ - outs = [] - for f in self.discriminators: - outs += [f(x)] - - return outs - - -class HiFiGANScaleDiscriminator(torch.nn.Module): - """HiFi-GAN scale discriminator module.""" - - def __init__( - self, - in_channels: int = 1, - out_channels: int = 1, - kernel_sizes: List[int] = [15, 41, 5, 3], - channels: int = 128, - max_downsample_channels: int = 1024, - max_groups: int = 16, - bias: int = True, - downsample_scales: List[int] = [2, 2, 4, 4, 1], - nonlinear_activation: str = "LeakyReLU", - nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, - use_weight_norm: bool = True, - use_spectral_norm: bool = False, - ): - """Initilize HiFiGAN scale discriminator module. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - kernel_sizes (List[int]): List of four kernel sizes. The first will be used - for the first conv layer, and the second is for downsampling part, and - the remaining two are for the last two output layers. - channels (int): Initial number of channels for conv layer. - max_downsample_channels (int): Maximum number of channels for downsampling - layers. 
- bias (bool): Whether to add bias parameter in convolution layers. - downsample_scales (List[int]): List of downsampling scales. - nonlinear_activation (str): Activation function module name. - nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation - function. - use_weight_norm (bool): Whether to use weight norm. If set to true, it will - be applied to all of the conv layers. - use_spectral_norm (bool): Whether to use spectral norm. If set to true, it - will be applied to all of the conv layers. - - """ - super().__init__() - self.layers = torch.nn.ModuleList() - - # check kernel size is valid - assert len(kernel_sizes) == 4 - for ks in kernel_sizes: - assert ks % 2 == 1 - - # add first layer - self.layers += [ - torch.nn.Sequential( - torch.nn.Conv1d( - in_channels, - channels, - # NOTE(kan-bayashi): Use always the same kernel size - kernel_sizes[0], - bias=bias, - padding=(kernel_sizes[0] - 1) // 2, - ), - getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), - ) - ] - - # add downsample layers - in_chs = channels - out_chs = channels - # NOTE(kan-bayashi): Remove hard coding? - groups = 4 - for downsample_scale in downsample_scales: - self.layers += [ - torch.nn.Sequential( - torch.nn.Conv1d( - in_chs, - out_chs, - kernel_size=kernel_sizes[1], - stride=downsample_scale, - padding=(kernel_sizes[1] - 1) // 2, - groups=groups, - bias=bias, - ), - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - ) - ] - in_chs = out_chs - # NOTE(kan-bayashi): Remove hard coding? - out_chs = min(in_chs * 2, max_downsample_channels) - # NOTE(kan-bayashi): Remove hard coding? - groups = min(groups * 4, max_groups) - - # add final layers - out_chs = min(in_chs * 2, max_downsample_channels) - self.layers += [ - torch.nn.Sequential( - torch.nn.Conv1d( - in_chs, - out_chs, - kernel_size=kernel_sizes[2], - stride=1, - padding=(kernel_sizes[2] - 1) // 2, - bias=bias, - ), - getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), - ) - ] - self.layers += [ - torch.nn.Conv1d( - out_chs, - out_channels, - kernel_size=kernel_sizes[3], - stride=1, - padding=(kernel_sizes[3] - 1) // 2, - bias=bias, - ), - ] - - if use_weight_norm and use_spectral_norm: - raise ValueError("Either use use_weight_norm or use_spectral_norm.") - - # apply weight norm - self.use_weight_norm = use_weight_norm - if use_weight_norm: - self.apply_weight_norm() - - # apply spectral norm - self.use_spectral_norm = use_spectral_norm - if use_spectral_norm: - self.apply_spectral_norm() - - # backward compatibility - self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) - - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T). - - Returns: - List[Tensor]: List of output tensors of each layer. 
- - """ - outs = [] - for f in self.layers: - x = f(x) - outs += [x] - - return outs - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - def apply_spectral_norm(self): - """Apply spectral normalization module from all of the layers.""" - - def _apply_spectral_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d): - torch.nn.utils.spectral_norm(m) - logging.debug(f"Spectral norm is applied to {m}.") - - self.apply(_apply_spectral_norm) - - def remove_weight_norm(self): - """Remove weight normalization module from all of the layers.""" - - def _remove_weight_norm(m): - try: - logging.debug(f"Weight norm is removed from {m}.") - torch.nn.utils.remove_weight_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_weight_norm) - - def remove_spectral_norm(self): - """Remove spectral normalization module from all of the layers.""" - - def _remove_spectral_norm(m): - try: - logging.debug(f"Spectral norm is removed from {m}.") - torch.nn.utils.remove_spectral_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_spectral_norm) - - def _load_state_dict_pre_hook( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - """Fix the compatibility of weight / spectral normalization issue. - - Some pretrained models are trained with configs that use weight / spectral - normalization, but actually, the norm is not applied. This causes the mismatch - of the parameters with configs. To solve this issue, when parameter mismatch - happens in loading pretrained model, we remove the norm from the current model. - - See also: - - https://github.com/espnet/espnet/pull/5240 - - https://github.com/espnet/espnet/pull/5249 - - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 - - """ - current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] - if self.use_weight_norm and any( - [k.endswith("weight") for k in current_module_keys] - ): - logging.warning( - "It seems weight norm is not applied in the pretrained model but the" - " current model uses it. To keep the compatibility, we remove the norm" - " from the current model. This may cause unexpected behavior due to the" - " parameter mismatch in finetuning. To avoid this issue, please change" - " the following parameters in config to false:\n" - " - discriminator_params.follow_official_norm\n" - " - discriminator_params.scale_discriminator_params.use_weight_norm\n" - " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" - "\n" - "See also:\n" - " - https://github.com/espnet/espnet/pull/5240\n" - " - https://github.com/espnet/espnet/pull/5249" - ) - self.remove_weight_norm() - self.use_weight_norm = False - for k in current_module_keys: - if k.endswith("weight_g") or k.endswith("weight_v"): - del state_dict[k] - - if self.use_spectral_norm and any( - [k.endswith("weight") for k in current_module_keys] - ): - logging.warning( - "It seems spectral norm is not applied in the pretrained model but the" - " current model uses it. To keep the compatibility, we remove the norm" - " from the current model. This may cause unexpected behavior due to the" - " parameter mismatch in finetuning. 
To avoid this issue, please change" - " the following parameters in config to false:\n" - " - discriminator_params.follow_official_norm\n" - " - discriminator_params.scale_discriminator_params.use_weight_norm\n" - " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" - "\n" - "See also:\n" - " - https://github.com/espnet/espnet/pull/5240\n" - " - https://github.com/espnet/espnet/pull/5249" - ) - self.remove_spectral_norm() - self.use_spectral_norm = False - for k in current_module_keys: - if ( - k.endswith("weight_u") - or k.endswith("weight_v") - or k.endswith("weight_orig") - ): - del state_dict[k] - - -class HiFiGANMultiScaleDiscriminator(torch.nn.Module): - """HiFi-GAN multi-scale discriminator module.""" - - def __init__( - self, - scales: int = 3, - downsample_pooling: str = "AvgPool1d", - # follow the official implementation setting - downsample_pooling_params: Dict[str, Any] = { - "kernel_size": 4, - "stride": 2, - "padding": 2, - }, - discriminator_params: Dict[str, Any] = { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [15, 41, 5, 3], - "channels": 128, - "max_downsample_channels": 1024, - "max_groups": 16, - "bias": True, - "downsample_scales": [2, 2, 4, 4, 1], - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - }, - follow_official_norm: bool = False, - ): - """Initilize HiFiGAN multi-scale discriminator module. - - Args: - scales (int): Number of multi-scales. - downsample_pooling (str): Pooling module name for downsampling of the - inputs. - downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling - module. - discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale - discriminator module. - follow_official_norm (bool): Whether to follow the norm setting of the - official implementaion. The first discriminator uses spectral norm - and the other discriminators use weight norm. - - """ - super().__init__() - self.discriminators = torch.nn.ModuleList() - - # add discriminators - for i in range(scales): - params = copy.deepcopy(discriminator_params) - if follow_official_norm: - if i == 0: - params["use_weight_norm"] = False - params["use_spectral_norm"] = True - else: - params["use_weight_norm"] = True - params["use_spectral_norm"] = False - self.discriminators += [HiFiGANScaleDiscriminator(**params)] - self.pooling = None - if scales > 1: - self.pooling = getattr(torch.nn, downsample_pooling)( - **downsample_pooling_params - ) - - def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T). - - Returns: - List[List[torch.Tensor]]: List of list of each discriminator outputs, - which consists of eachlayer output tensors. 
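The _load_state_dict_pre_hook above works by inspecting parameter-key suffixes, because applying weight norm replaces a conv's plain "weight" parameter with the "weight_g"/"weight_v" pair (and, as the hook's deletion list shows, spectral norm leaves "weight_orig"/"weight_u"/"weight_v" keys instead). A minimal standalone demonstration of the renaming:

import torch

conv = torch.nn.Conv1d(2, 2, 3)
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight']
torch.nn.utils.weight_norm(conv)
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight_g', 'weight_v']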
- - """ - outs = [] - for f in self.discriminators: - outs += [f(x)] - if self.pooling is not None: - x = self.pooling(x) - - return outs - - -class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): - """HiFi-GAN multi-scale + multi-period discriminator module.""" - - def __init__( - self, - # Multi-scale discriminator related - scales: int = 3, - scale_downsample_pooling: str = "AvgPool1d", - scale_downsample_pooling_params: Dict[str, Any] = { - "kernel_size": 4, - "stride": 2, - "padding": 2, - }, - scale_discriminator_params: Dict[str, Any] = { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [15, 41, 5, 3], - "channels": 128, - "max_downsample_channels": 1024, - "max_groups": 16, - "bias": True, - "downsample_scales": [2, 2, 4, 4, 1], - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - }, - follow_official_norm: bool = True, - # Multi-period discriminator related - periods: List[int] = [2, 3, 5, 7, 11], - period_discriminator_params: Dict[str, Any] = { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [5, 3], - "channels": 32, - "downsample_scales": [3, 3, 3, 3, 1], - "max_downsample_channels": 1024, - "bias": True, - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - ): - """Initilize HiFiGAN multi-scale + multi-period discriminator module. - - Args: - scales (int): Number of multi-scales. - scale_downsample_pooling (str): Pooling module name for downsampling of the - inputs. - scale_downsample_pooling_params (dict): Parameters for the above pooling - module. - scale_discriminator_params (dict): Parameters for hifi-gan scale - discriminator module. - follow_official_norm (bool): Whether to follow the norm setting of the - official implementaion. The first discriminator uses spectral norm and - the other discriminators use weight norm. - periods (list): List of periods. - period_discriminator_params (dict): Parameters for hifi-gan period - discriminator module. The period parameter will be overwritten. - - """ - super().__init__() - self.msd = HiFiGANMultiScaleDiscriminator( - scales=scales, - downsample_pooling=scale_downsample_pooling, - downsample_pooling_params=scale_downsample_pooling_params, - discriminator_params=scale_discriminator_params, - follow_official_norm=follow_official_norm, - ) - self.mpd = HiFiGANMultiPeriodDiscriminator( - periods=periods, - discriminator_params=period_discriminator_params, - ) - - def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T). - - Returns: - List[List[Tensor]]: List of list of each discriminator outputs, - which consists of each layer output tensors. Multi scale and - multi period ones are concatenated. - - """ - msd_outs = self.msd(x) - mpd_outs = self.mpd(x) - return msd_outs + mpd_outs diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py deleted file mode 100755 index cf1067dfc..000000000 --- a/egs/ljspeech/TTS/vits/infer.py +++ /dev/null @@ -1,256 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script performs model inference on test set. - -Usage: -./vits/infer.py \ - --epoch 1000 \ - --exp-dir ./vits/exp \ - --max-duration 500 -""" - - -import argparse -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path -from typing import List - -import k2 -import torch -import torch.nn as nn -import torchaudio -from tokenizer import Tokenizer -from train import get_model, get_params -from tts_datamodule import LJSpeechTtsDataModule - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. - It controls the model size. low -> runs faster. - """, - ) - - return parser - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - tokenizer: Tokenizer, -) -> None: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - tokenizer: - Used to convert text to phonemes. - """ - - # Background worker save audios to disk. - def _save_worker( - batch_size: int, - cut_ids: List[str], - audio: torch.Tensor, - audio_pred: torch.Tensor, - audio_lens: List[int], - audio_lens_pred: List[int], - ): - for i in range(batch_size): - torchaudio.save( - str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), - audio[i : i + 1, : audio_lens[i]], - sample_rate=params.sampling_rate, - ) - torchaudio.save( - str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), - audio_pred[i : i + 1, : audio_lens_pred[i]], - sample_rate=params.sampling_rate, - ) - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
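Further down, infer_dataset converts the predicted per-token durations (in frames) into waveform lengths in samples by multiplying the total frame count with the frame shift. With made-up numbers and a 256-sample frame shift (an assumption for illustration; the actual value comes from params.frame_shift):

import torch

frame_shift = 256                                  # samples per frame (illustrative)
durations = torch.tensor([[2.0, 3.0, 1.0],         # (B, T_text) predicted durations in frames
                          [4.0, 1.0, 2.0]])
audio_lens_pred = (durations.sum(1) * frame_shift).to(torch.int64).tolist()
print(audio_lens_pred)                             # [1536, 1792]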
- - futures = [] - with ThreadPoolExecutor(max_workers=1) as executor: - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["tokens"]) - - tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - audio = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - - audio_pred, _, durations = model.inference_batch( - text=tokens, text_lengths=tokens_lens - ) - audio_pred = audio_pred.detach().cpu() - # convert to samples - audio_lens_pred = ( - (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() - ) - - futures.append( - executor.submit( - _save_worker, - batch_size, - cut_ids, - audio, - audio_pred, - audio_lens, - audio_lens_pred, - ) - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - # return results - for f in futures: - f.result() - - -@torch.no_grad() -def main(): - parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(f"Device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to(device) - model.eval() - - num_param_g = sum([p.numel() for p in model.generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - # we need cut ids to organize tts results. - args.return_cuts = True - ljspeech = LJSpeechTtsDataModule(args) - - test_cuts = ljspeech.test_cuts() - test_dl = ljspeech.test_dataloaders(test_cuts) - - infer_dataset( - dl=test_dl, - params=params, - model=model, - tokenizer=tokenizer, - ) - - logging.info(f"Wav files are saved to {params.save_wav_dir}") - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py deleted file mode 100644 index 2f4dc9bc0..000000000 --- a/egs/ljspeech/TTS/vits/loss.py +++ /dev/null @@ -1,335 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""HiFiGAN-related loss modules. 
- -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. - -""" - -from typing import List, Tuple, Union - -import torch -import torch.distributions as D -import torch.nn.functional as F -from lhotse.features.kaldi import Wav2LogFilterBank - - -class GeneratorAdversarialLoss(torch.nn.Module): - """Generator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "mse", - ): - """Initialize GeneratorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.criterion = self._mse_loss - else: - self.criterion = self._hinge_loss - - def forward( - self, - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> torch.Tensor: - """Calcualate generator adversarial loss. - - Args: - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs.. - - Returns: - Tensor: Generator adversarial loss value. - - """ - if isinstance(outputs, (tuple, list)): - adv_loss = 0.0 - for i, outputs_ in enumerate(outputs): - if isinstance(outputs_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_ = outputs_[-1] - adv_loss += self.criterion(outputs_) - if self.average_by_discriminators: - adv_loss /= i + 1 - else: - adv_loss = self.criterion(outputs) - - return adv_loss - - def _mse_loss(self, x): - return F.mse_loss(x, x.new_ones(x.size())) - - def _hinge_loss(self, x): - return -x.mean() - - -class DiscriminatorAdversarialLoss(torch.nn.Module): - """Discriminator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "mse", - ): - """Initialize DiscriminatorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.fake_criterion = self._mse_fake_loss - self.real_criterion = self._mse_real_loss - else: - self.fake_criterion = self._hinge_fake_loss - self.real_criterion = self._hinge_real_loss - - def forward( - self, - outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calcualate discriminator adversarial loss. - - Args: - outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from generator. - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from groundtruth. - - Returns: - Tensor: Discriminator real loss value. - Tensor: Discriminator fake loss value. 
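With the default "mse" (LSGAN) setting, the two adversarial loss modules here reduce to simple least-squares targets: the generator drives D(fake) towards 1, while the discriminator drives D(real) towards 1 and D(fake) towards 0. A compact sketch with stand-in discriminator outputs (illustrative values only):

import torch
import torch.nn.functional as F

d_fake = torch.tensor([0.2, 0.4])  # stand-in discriminator outputs on generated audio
d_real = torch.tensor([0.9, 1.1])  # stand-in outputs on real audio

gen_adv_loss = F.mse_loss(d_fake, torch.ones_like(d_fake))
disc_real_loss = F.mse_loss(d_real, torch.ones_like(d_real))
disc_fake_loss = F.mse_loss(d_fake, torch.zeros_like(d_fake))
print(gen_adv_loss.item(), disc_real_loss.item(), disc_fake_loss.item())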
- - """ - if isinstance(outputs, (tuple, list)): - real_loss = 0.0 - fake_loss = 0.0 - for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): - if isinstance(outputs_hat_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_hat_ = outputs_hat_[-1] - outputs_ = outputs_[-1] - real_loss += self.real_criterion(outputs_) - fake_loss += self.fake_criterion(outputs_hat_) - if self.average_by_discriminators: - fake_loss /= i + 1 - real_loss /= i + 1 - else: - real_loss = self.real_criterion(outputs) - fake_loss = self.fake_criterion(outputs_hat) - - return real_loss, fake_loss - - def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_ones(x.size())) - - def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_zeros(x.size())) - - def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) - - def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) - - -class FeatureMatchLoss(torch.nn.Module): - """Feature matching loss module.""" - - def __init__( - self, - average_by_layers: bool = True, - average_by_discriminators: bool = True, - include_final_outputs: bool = False, - ): - """Initialize FeatureMatchLoss module. - - Args: - average_by_layers (bool): Whether to average the loss by the number - of layers. - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - include_final_outputs (bool): Whether to include the final output of - each discriminator for loss calculation. - - """ - super().__init__() - self.average_by_layers = average_by_layers - self.average_by_discriminators = average_by_discriminators - self.include_final_outputs = include_final_outputs - - def forward( - self, - feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], - feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], - ) -> torch.Tensor: - """Calculate feature matching loss. - - Args: - feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from generator's outputs. - feats (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from groundtruth.. - - Returns: - Tensor: Feature matching loss value. 
- - """ - feat_match_loss = 0.0 - for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): - feat_match_loss_ = 0.0 - if not self.include_final_outputs: - feats_hat_ = feats_hat_[:-1] - feats_ = feats_[:-1] - for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) - if self.average_by_layers: - feat_match_loss_ /= j + 1 - feat_match_loss += feat_match_loss_ - if self.average_by_discriminators: - feat_match_loss /= i + 1 - - return feat_match_loss - - -class MelSpectrogramLoss(torch.nn.Module): - """Mel-spectrogram loss.""" - - def __init__( - self, - sampling_rate: int = 22050, - frame_length: int = 1024, # in samples - frame_shift: int = 256, # in samples - n_mels: int = 80, - use_fft_mag: bool = True, - ): - super().__init__() - self.wav_to_mel = Wav2LogFilterBank( - sampling_rate=sampling_rate, - frame_length=frame_length / sampling_rate, # in second - frame_shift=frame_shift / sampling_rate, # in second - use_fft_mag=use_fft_mag, - num_filters=n_mels, - ) - - def forward( - self, - y_hat: torch.Tensor, - y: torch.Tensor, - return_mel: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: - """Calculate Mel-spectrogram loss. - - Args: - y_hat (Tensor): Generated waveform tensor (B, 1, T). - y (Tensor): Groundtruth waveform tensor (B, 1, T). - spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor - (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth - waveform. - - Returns: - Tensor: Mel-spectrogram loss value. - - """ - mel_hat = self.wav_to_mel(y_hat.squeeze(1)) - mel = self.wav_to_mel(y.squeeze(1)) - mel_loss = F.l1_loss(mel_hat, mel) - - if return_mel: - return mel_loss, (mel_hat, mel) - - return mel_loss - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py - -"""VITS-related loss modules. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - - -class KLDivergenceLoss(torch.nn.Module): - """KL divergence loss.""" - - def forward( - self, - z_p: torch.Tensor, - logs_q: torch.Tensor, - m_p: torch.Tensor, - logs_p: torch.Tensor, - z_mask: torch.Tensor, - ) -> torch.Tensor: - """Calculate KL divergence loss. - - Args: - z_p (Tensor): Flow hidden representation (B, H, T_feats). - logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). - m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). - logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). - z_mask (Tensor): Mask tensor (B, 1, T_feats). - - Returns: - Tensor: KL divergence loss. - - """ - z_p = z_p.float() - logs_q = logs_q.float() - m_p = m_p.float() - logs_p = logs_p.float() - z_mask = z_mask.float() - kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) - kl = torch.sum(kl * z_mask) - loss = kl / torch.sum(z_mask) - - return loss - - -class KLDivergenceLossWithoutFlow(torch.nn.Module): - """KL divergence loss without flow.""" - - def forward( - self, - m_q: torch.Tensor, - logs_q: torch.Tensor, - m_p: torch.Tensor, - logs_p: torch.Tensor, - ) -> torch.Tensor: - """Calculate KL divergence loss without flow. - - Args: - m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). - logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). - m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). - logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
- """ - posterior_norm = D.Normal(m_q, torch.exp(logs_q)) - prior_norm = D.Normal(m_p, torch.exp(logs_p)) - loss = D.kl_divergence(posterior_norm, prior_norm).mean() - return loss diff --git a/egs/ljspeech/TTS/vits/monotonic_align/.gitignore b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore deleted file mode 100644 index 3def4ae26..000000000 --- a/egs/ljspeech/TTS/vits/monotonic_align/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -build -core.c -*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py deleted file mode 100644 index 5dc3641e5..000000000 --- a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py +++ /dev/null @@ -1,85 +0,0 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py - -"""Maximum path calculation module. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -import warnings - -import numpy as np -import torch - -try: - from numba import njit, prange -except ModuleNotFoundError as ex: - raise RuntimeError(f"{ex}/nPlease run\n pip install numba") - -try: - from .core import maximum_path_c - - is_cython_avalable = True -except ImportError: - is_cython_avalable = False - warnings.warn( - "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. " - "If you want to use the cython version, please build it as follows: " - "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`" - ) - - -def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: - """Calculate maximum path. - - Args: - neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). - attn_mask (Tensor): Attention mask (B, T_feats, T_text). - - Returns: - Tensor: Maximum path tensor (B, T_feats, T_text). 
- - """ - device, dtype = neg_x_ent.device, neg_x_ent.dtype - neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) - path = np.zeros(neg_x_ent.shape, dtype=np.int32) - t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) - t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) - if is_cython_avalable: - maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) - else: - maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) - - return torch.from_numpy(path).to(device=device, dtype=dtype) - - -@njit -def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): - """Calculate a single maximum path with numba.""" - index = t_x - 1 - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[y - 1, x] - if x == 0: - if y == 0: - v_prev = 0.0 - else: - v_prev = max_neg_val - else: - v_prev = value[y - 1, x - 1] - value[y, x] += max(v_prev, v_cur) - - for y in range(t_y - 1, -1, -1): - path[y, index] = 1 - if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): - index = index - 1 - - -@njit(parallel=True) -def maximum_path_numba(paths, values, t_ys, t_xs): - """Calculate batch maximum path with numba.""" - for i in prange(paths.shape[0]): - maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx deleted file mode 100644 index c02c2d02e..000000000 --- a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx +++ /dev/null @@ -1,51 +0,0 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx - -"""Maximum path calculation module with cython optimization. - -This code is copied from https://github.com/jaywalnut310/vits and modifed code format. 
- -""" - -cimport cython - -from cython.parallel import prange - - -@cython.boundscheck(False) -@cython.wraparound(False) -cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: - cdef int x - cdef int y - cdef float v_prev - cdef float v_cur - cdef float tmp - cdef int index = t_x - 1 - - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[y - 1, x] - if x == 0: - if y == 0: - v_prev = 0.0 - else: - v_prev = max_neg_val - else: - v_prev = value[y - 1, x - 1] - value[y, x] += max(v_prev, v_cur) - - for y in range(t_y - 1, -1, -1): - path[y, index] = 1 - if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): - index = index - 1 - - -@cython.boundscheck(False) -@cython.wraparound(False) -cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: - cdef int b = paths.shape[0] - cdef int i - for i in prange(b, nogil=True): - maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py deleted file mode 100644 index 33d75e176..000000000 --- a/egs/ljspeech/TTS/vits/monotonic_align/setup.py +++ /dev/null @@ -1,31 +0,0 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py -"""Setup cython code.""" - -from Cython.Build import cythonize -from setuptools import Extension, setup -from setuptools.command.build_ext import build_ext as _build_ext - - -class build_ext(_build_ext): - """Overwrite build_ext.""" - - def finalize_options(self): - """Prevent numpy from thinking it is still in its setup process.""" - _build_ext.finalize_options(self) - __builtins__.__NUMPY_SETUP__ = False - import numpy - - self.include_dirs.append(numpy.get_include()) - - -exts = [ - Extension( - name="core", - sources=["core.pyx"], - ) -] -setup( - name="monotonic_align", - ext_modules=cythonize(exts, language_level=3), - cmdclass={"build_ext": build_ext}, -) diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py deleted file mode 100644 index 1104fb864..000000000 --- a/egs/ljspeech/TTS/vits/posterior_encoder.py +++ /dev/null @@ -1,117 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Posterior encoder module in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -from typing import Optional, Tuple - -import torch -from wavenet import Conv1d, WaveNet - -from icefall.utils import make_pad_mask - - -class PosteriorEncoder(torch.nn.Module): - """Posterior encoder module in VITS. - - This is a module of posterior encoder described in `Conditional Variational - Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. - - .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End - Text-to-Speech`: https://arxiv.org/abs/2006.04558 - """ - - def __init__( - self, - in_channels: int = 513, - out_channels: int = 192, - hidden_channels: int = 192, - kernel_size: int = 5, - layers: int = 16, - stacks: int = 1, - base_dilation: int = 1, - global_channels: int = -1, - dropout_rate: float = 0.0, - bias: bool = True, - use_weight_norm: bool = True, - ): - """Initilialize PosteriorEncoder module. 
- - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - hidden_channels (int): Number of hidden channels. - kernel_size (int): Kernel size in WaveNet. - layers (int): Number of layers of WaveNet. - stacks (int): Number of repeat stacking of WaveNet. - base_dilation (int): Base dilation factor. - global_channels (int): Number of global conditioning channels. - dropout_rate (float): Dropout rate. - bias (bool): Whether to use bias parameters in conv. - use_weight_norm (bool): Whether to apply weight norm. - - """ - super().__init__() - - # define modules - self.input_conv = Conv1d(in_channels, hidden_channels, 1) - self.encoder = WaveNet( - in_channels=-1, - out_channels=-1, - kernel_size=kernel_size, - layers=layers, - stacks=stacks, - base_dilation=base_dilation, - residual_channels=hidden_channels, - aux_channels=-1, - gate_channels=hidden_channels * 2, - skip_channels=hidden_channels, - global_channels=global_channels, - dropout_rate=dropout_rate, - bias=bias, - use_weight_norm=use_weight_norm, - use_first_conv=False, - use_last_conv=False, - scale_residual=False, - scale_skip_connect=True, - ) - self.proj = Conv1d(hidden_channels, out_channels * 2, 1) - - def forward( - self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, in_channels, T_feats). - x_lengths (Tensor): Length tensor (B,). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). - Tensor: Projected mean tensor (B, out_channels, T_feats). - Tensor: Projected scale tensor (B, out_channels, T_feats). - Tensor: Mask tensor for input tensor (B, 1, T_feats). - - """ - x_mask = ( - (~make_pad_mask(x_lengths)) - .unsqueeze(1) - .to( - dtype=x.dtype, - device=x.device, - ) - ) - x = self.input_conv(x) * x_mask - x = self.encoder(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = stats.split(stats.size(1) // 2, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - - return z, m, logs, x_mask diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py deleted file mode 100644 index f9a2a3786..000000000 --- a/egs/ljspeech/TTS/vits/residual_coupling.py +++ /dev/null @@ -1,228 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Residual affine coupling modules in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -from typing import Optional, Tuple, Union - -import torch -from flow import FlipFlow -from wavenet import WaveNet - - -class ResidualAffineCouplingBlock(torch.nn.Module): - """Residual affine coupling block module. - - This is a module of residual affine coupling block, which used as "Flow" in - `Conditional Variational Autoencoder with Adversarial Learning for End-to-End - Text-to-Speech`_. - - .. 
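The last lines of `PosteriorEncoder.forward()` above use the standard reparameterization trick. Stripped of the WaveNet encoder, the sampling step looks like this (toy shapes):

```python
import torch

B, C, T = 2, 192, 50               # (batch, out_channels, frames)
m = torch.randn(B, C, T)           # projected mean
logs = 0.1 * torch.randn(B, C, T)  # projected log-std
x_mask = torch.ones(B, 1, T)       # 1 for real frames, 0 for padding

# z = mean + noise * std, with padded frames zeroed out
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
```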
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End - Text-to-Speech`: https://arxiv.org/abs/2006.04558 - - """ - - def __init__( - self, - in_channels: int = 192, - hidden_channels: int = 192, - flows: int = 4, - kernel_size: int = 5, - base_dilation: int = 1, - layers: int = 4, - global_channels: int = -1, - dropout_rate: float = 0.0, - use_weight_norm: bool = True, - bias: bool = True, - use_only_mean: bool = True, - ): - """Initilize ResidualAffineCouplingBlock module. - - Args: - in_channels (int): Number of input channels. - hidden_channels (int): Number of hidden channels. - flows (int): Number of flows. - kernel_size (int): Kernel size for WaveNet. - base_dilation (int): Base dilation factor for WaveNet. - layers (int): Number of layers of WaveNet. - stacks (int): Number of stacks of WaveNet. - global_channels (int): Number of global channels. - dropout_rate (float): Dropout rate. - use_weight_norm (bool): Whether to use weight normalization in WaveNet. - bias (bool): Whether to use bias paramters in WaveNet. - use_only_mean (bool): Whether to estimate only mean. - - """ - super().__init__() - - self.flows = torch.nn.ModuleList() - for i in range(flows): - self.flows += [ - ResidualAffineCouplingLayer( - in_channels=in_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - base_dilation=base_dilation, - layers=layers, - stacks=1, - global_channels=global_channels, - dropout_rate=dropout_rate, - use_weight_norm=use_weight_norm, - bias=bias, - use_only_mean=use_only_mean, - ) - ] - self.flows += [FlipFlow()] - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - g: Optional[torch.Tensor] = None, - inverse: bool = False, - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, in_channels, T). - x_lengths (Tensor): Length tensor (B,). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, in_channels, T). - - """ - if not inverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, inverse=inverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, inverse=inverse) - return x - - -class ResidualAffineCouplingLayer(torch.nn.Module): - """Residual affine coupling layer.""" - - def __init__( - self, - in_channels: int = 192, - hidden_channels: int = 192, - kernel_size: int = 5, - base_dilation: int = 1, - layers: int = 5, - stacks: int = 1, - global_channels: int = -1, - dropout_rate: float = 0.0, - use_weight_norm: bool = True, - bias: bool = True, - use_only_mean: bool = True, - ): - """Initialzie ResidualAffineCouplingLayer module. - - Args: - in_channels (int): Number of input channels. - hidden_channels (int): Number of hidden channels. - kernel_size (int): Kernel size for WaveNet. - base_dilation (int): Base dilation factor for WaveNet. - layers (int): Number of layers of WaveNet. - stacks (int): Number of stacks of WaveNet. - global_channels (int): Number of global channels. - dropout_rate (float): Dropout rate. - use_weight_norm (bool): Whether to use weight normalization in WaveNet. - bias (bool): Whether to use bias paramters in WaveNet. - use_only_mean (bool): Whether to estimate only mean. 
- - """ - assert in_channels % 2 == 0, "in_channels should be divisible by 2" - super().__init__() - self.half_channels = in_channels // 2 - self.use_only_mean = use_only_mean - - # define modules - self.input_conv = torch.nn.Conv1d( - self.half_channels, - hidden_channels, - 1, - ) - self.encoder = WaveNet( - in_channels=-1, - out_channels=-1, - kernel_size=kernel_size, - layers=layers, - stacks=stacks, - base_dilation=base_dilation, - residual_channels=hidden_channels, - aux_channels=-1, - gate_channels=hidden_channels * 2, - skip_channels=hidden_channels, - global_channels=global_channels, - dropout_rate=dropout_rate, - bias=bias, - use_weight_norm=use_weight_norm, - use_first_conv=False, - use_last_conv=False, - scale_residual=False, - scale_skip_connect=True, - ) - if use_only_mean: - self.proj = torch.nn.Conv1d( - hidden_channels, - self.half_channels, - 1, - ) - else: - self.proj = torch.nn.Conv1d( - hidden_channels, - self.half_channels * 2, - 1, - ) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - g: Optional[torch.Tensor] = None, - inverse: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, in_channels, T). - x_lengths (Tensor): Length tensor (B,). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, in_channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - xa, xb = x.split(x.size(1) // 2, dim=1) - h = self.input_conv(xa) * x_mask - h = self.encoder(h, x_mask, g=g) - stats = self.proj(h) * x_mask - if not self.use_only_mean: - m, logs = stats.split(stats.size(1) // 2, dim=1) - else: - m = stats - logs = torch.zeros_like(m) - - if not inverse: - xb = m + xb * torch.exp(logs) * x_mask - x = torch.cat([xa, xb], 1) - logdet = torch.sum(logs, [1, 2]) - return x, logdet - else: - xb = (xb - m) * torch.exp(-logs) * x_mask - x = torch.cat([xa, xb], 1) - return x diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py deleted file mode 100755 index 4faaa96a5..000000000 --- a/egs/ljspeech/TTS/vits/test_model.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from tokenizer import Tokenizer -from train import get_model, get_params - - -def test_model_type(model_type): - tokens = "./data/tokens.txt" - - params = get_params() - - tokenizer = Tokenizer(tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_type = model_type - - model = get_model(params) - generator = model.generator - - num_param = sum([p.numel() for p in generator.parameters()]) - print( - f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M" - ) - - -def main(): - test_model_type("high") # 35.63 M - test_model_type("low") # 7.55 M - test_model_type("medium") # 23.61 M - - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py deleted file mode 100755 index b3805fadb..000000000 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to test the exported onnx model by vits/export-onnx.py - -Use the onnx model to generate a wav: -./vits/test_onnx.py \ - --model-filename vits/exp/vits-epoch-1000.onnx \ - --tokens data/tokens.txt -""" - - -import argparse -import logging - -import onnxruntime as ort -import torch -import torchaudio -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model.", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--text", - type=str, - default="Ask not what your country can do for you; ask what you can do for your country.", - help="Text to generate speech for", - ) - - parser.add_argument( - "--output-filename", - type=str, - default="test_onnx.wav", - help="Filename to save the generated wave file.", - ) - - return parser - - -class OnnxModel: - def __init__(self, model_filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.model = ort.InferenceSession( - model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - - metadata = self.model.get_modelmeta().custom_metadata_map - self.sample_rate = int(metadata["sample_rate"]) - - def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: - """ - Args: - tokens: - A 1-D tensor of shape (1, T) - Returns: - A tensor of shape (1, T') - """ - noise_scale = torch.tensor([0.667], dtype=torch.float32) - noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) - alpha = 
torch.tensor([1.0], dtype=torch.float32) - - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: tokens.numpy(), - self.model.get_inputs()[1].name: tokens_lens.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: alpha.numpy(), - self.model.get_inputs()[4].name: noise_scale_dur.numpy(), - }, - )[0] - return torch.from_numpy(out) - - -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - tokenizer = Tokenizer(args.tokens) - - logging.info("About to create onnx model") - model = OnnxModel(args.model_filename) - - text = args.text - tokens = tokenizer.texts_to_token_ids( - [text], intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = torch.tensor(tokens) # (1, T) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) - audio = model(tokens, tokens_lens) # (1, T') - - output_filename = args.output_filename - torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) - logging.info(f"Saved to {output_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py deleted file mode 100644 index 9b21ed9cb..000000000 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ /dev/null @@ -1,685 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Text encoder module in VITS. - -This code is based on - - https://github.com/jaywalnut310/vits - - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py - - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py -""" - -import copy -import math -from typing import Optional, Tuple - -import torch -from torch import Tensor, nn - -from icefall.utils import is_jit_tracing, make_pad_mask - - -class TextEncoder(torch.nn.Module): - """Text encoder module in VITS. - - This is a module of text encoder described in `Conditional Variational Autoencoder - with Adversarial Learning for End-to-End Text-to-Speech`. - """ - - def __init__( - self, - vocabs: int, - d_model: int = 192, - num_heads: int = 2, - dim_feedforward: int = 768, - cnn_module_kernel: int = 5, - num_layers: int = 6, - dropout: float = 0.1, - ): - """Initialize TextEncoder module. - - Args: - vocabs (int): Vocabulary size. 
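`OnnxModel.__call__()` above feeds the five inputs positionally (tokens, token lengths, noise_scale, alpha, noise_scale_dur), so it can help to dump the exported graph's input/output signature first. A small sketch using the same onnxruntime API; the model path is the one from the usage string above:

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "vits/exp/vits-epoch-1000.onnx", providers=["CPUExecutionProvider"]
)
for i, node in enumerate(sess.get_inputs()):
    print("input", i, node.name, node.shape, node.type)
for node in sess.get_outputs():
    print("output", node.name, node.shape, node.type)
print(sess.get_modelmeta().custom_metadata_map)  # includes "sample_rate"
```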
- d_model (int): attention dimension - num_heads (int): number of attention heads - dim_feedforward (int): feedforward dimention - cnn_module_kernel (int): convolution kernel size - num_layers (int): number of encoder layers - dropout (float): dropout rate - """ - super().__init__() - self.d_model = d_model - - # define modules - self.emb = torch.nn.Embedding(vocabs, d_model) - torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) - - # We use conformer as text encoder - self.encoder = Transformer( - d_model=d_model, - num_heads=num_heads, - dim_feedforward=dim_feedforward, - cnn_module_kernel=cnn_module_kernel, - num_layers=num_layers, - dropout=dropout, - ) - - self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) - - def forward( - self, - x: torch.Tensor, - x_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input index tensor (B, T_text). - x_lengths (Tensor): Length tensor (B,). - - Returns: - Tensor: Encoded hidden representation (B, embed_dim, T_text). - Tensor: Projected mean tensor (B, embed_dim, T_text). - Tensor: Projected scale tensor (B, embed_dim, T_text). - Tensor: Mask tensor for input tensor (B, 1, T_text). - - """ - # (B, T_text, embed_dim) - x = self.emb(x) * math.sqrt(self.d_model) - - assert x.size(1) == x_lengths.max().item() - - # (B, T_text) - pad_mask = make_pad_mask(x_lengths) - - # encoder assume the channel last (B, T_text, embed_dim) - x = self.encoder(x, key_padding_mask=pad_mask) - # Note: attention_dim == embed_dim - - # convert the channel first (B, embed_dim, T_text) - x = x.transpose(1, 2) - non_pad_mask = (~pad_mask).unsqueeze(1) - stats = self.proj(x) * non_pad_mask - m, logs = stats.split(stats.size(1) // 2, dim=1) - - return x, m, logs, non_pad_mask - - -class Transformer(nn.Module): - """ - Args: - d_model (int): attention dimension - num_heads (int): number of attention heads - dim_feedforward (int): feedforward dimention - cnn_module_kernel (int): convolution kernel size - num_layers (int): number of encoder layers - dropout (float): dropout rate - """ - - def __init__( - self, - d_model: int = 192, - num_heads: int = 2, - dim_feedforward: int = 768, - cnn_module_kernel: int = 5, - num_layers: int = 6, - dropout: float = 0.1, - ) -> None: - super().__init__() - - self.num_layers = num_layers - self.d_model = d_model - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - num_heads=num_heads, - dim_feedforward=dim_feedforward, - cnn_module_kernel=cnn_module_kernel, - dropout=dropout, - ) - self.encoder = TransformerEncoder(encoder_layer, num_layers) - self.after_norm = nn.LayerNorm(d_model) - - def forward( - self, x: Tensor, key_padding_mask: Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - lengths: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - """ - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) - - x = self.after_norm(x) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x - - -class TransformerEncoderLayer(nn.Module): - """ - TransformerEncoderLayer is made up of self-attn and feedforward. - - Args: - d_model: the number of expected features in the input. 
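`TextEncoder.forward()` above builds its masks from `icefall.utils.make_pad_mask`, which returns `True` at padded positions. A tiny illustration (assumes icefall is installed):

```python
import torch
from icefall.utils import make_pad_mask

lengths = torch.tensor([4, 2, 3])
pad_mask = make_pad_mask(lengths)        # (B, max_len), True where padded
print(pad_mask)
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [False, False, False,  True]])
non_pad_mask = (~pad_mask).unsqueeze(1)  # (B, 1, max_len), as used for `stats`
```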
- num_heads: the number of heads in the multi-head attention models. - dim_feedforward: the dimension of the feed-forward network model. - dropout: the dropout value (default=0.1). - """ - - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - cnn_module_kernel: int, - dropout: float = 0.1, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.self_attn = RelPositionMultiheadAttention( - d_model, num_heads, dropout=dropout - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - - self.ff_scale = 0.5 - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the transformer encoder layer. - - Args: - src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). - pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). - key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) - """ - # macaron style feed-forward module - src = src + self.ff_scale * self.dropout( - self.feed_forward_macaron(self.norm_ff_macaron(src)) - ) - - # multi-head self-attention module - src_attn = self.self_attn( - self.norm_mha(src), - pos_emb=pos_emb, - key_padding_mask=key_padding_mask, - ) - src = src + self.dropout(src_attn) - - # convolution module - src = src + self.dropout(self.conv_module(self.norm_conv(src))) - - # feed-forward module - src = src + self.dropout(self.feed_forward(self.norm_ff(src))) - - src = self.norm_final(src) - - return src - - -class TransformerEncoder(nn.Module): - r"""TransformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the TransformerEncoderLayer class. - num_layers: the number of sub-encoder-layers in the encoder. - """ - - def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: - super().__init__() - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). - pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). - key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) - """ - output = src - - for layer_index, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - key_padding_mask=key_padding_mask, - ) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. 
- - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - x_size = x.size(1) - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. 
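`RelPositionalEncoding` above returns embeddings for every relative offset from -(T-1) to +(T-1), i.e. 2*T - 1 positions. A quick shape check (a sketch, assuming `text_encoder.py` is importable):

```python
import torch
from text_encoder import RelPositionalEncoding

pos_enc = RelPositionalEncoding(d_model=192, dropout_rate=0.0)
x = torch.randn(2, 17, 192)           # (batch, time, d_model)
x_scaled, pos_emb = pos_enc(x)
print(x_scaled.shape, pos_emb.shape)  # (2, 17, 192) and (1, 33, 192) == (1, 2*17 - 1, 192)
```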
- - Args: - x: Input tensor (batch, head, seq_len, 2*seq_len-1). - - Returns: - Tensor: tensor of shape (batch, head, seq_len, seq_len) - """ - (batch_size, num_heads, seq_len, n) = x.shape - - if not is_jit_tracing(): - assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" - - if is_jit_tracing(): - rows = torch.arange(start=seq_len - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - - x = x.reshape(-1, n) - x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, seq_len, seq_len) - return x - else: - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, seq_len, seq_len), - (batch_stride, head_stride, time_stride - n_stride, n_stride), - storage_offset=n_stride * (seq_len - 1), - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: Input tensor of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim) - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - Its shape is (batch_size, seq_len). - - Outputs: - A tensor of shape (seq_len, batch_size, embed_dim). - """ - seq_len, batch_size, _ = x.shape - scaling = float(self.head_dim) ** -0.5 - - q, k, v = self.in_proj(x).chunk(3, dim=-1) - - q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - v = ( - v.contiguous() - .view(seq_len, batch_size * self.num_heads, self.head_dim) - .transpose(0, 1) - ) - - q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) - - p = self.linear_pos(pos_emb).view( - pos_emb.size(0), -1, self.num_heads, self.head_dim - ) - # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) - p = p.permute(0, 2, 3, 1) - - # (batch_size, num_head, seq_len, head_dim) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch_size, num_head, seq_len, seq_len) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch_size, num_head, seq_len, 2*seq_len-1) - matrix_bd = self.rel_shift( - matrix_bd - ) # (batch_size, num_head, seq_len, seq_len) - - # (batch_size, num_head, seq_len, seq_len) - attn_output_weights = (matrix_ac + matrix_bd) * scaling - attn_output_weights = attn_output_weights.view( - batch_size * self.num_heads, seq_len, seq_len - ) - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len) - attn_output_weights = attn_output_weights.view( - batch_size, self.num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - batch_size * 
self.num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=self.dropout, training=self.training - ) - - # (batch_size * num_head, seq_len, head_dim) - attn_output = torch.bmm(attn_output_weights, v) - assert attn_output.shape == ( - batch_size * self.num_heads, - seq_len, - self.head_dim, - ) - - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, batch_size, self.embed_dim) - ) - # (seq_len, batch_size, embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - padding = (kernel_size - 1) // 2 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias=bias, - ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
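The `rel_shift()` trick above converts scores indexed by relative offset into scores indexed by key position, so that `output[..., i, j] == input[..., i, (T - 1) + (j - i)]`. A small check against a naive slicing implementation (a sketch, assuming `text_encoder.py` is importable):

```python
import torch
from text_encoder import RelPositionMultiheadAttention

attn = RelPositionMultiheadAttention(embed_dim=192, num_heads=2)
b, h, t = 1, 2, 5
x = torch.randn(b, h, t, 2 * t - 1)   # scores over relative offsets
shifted = attn.rel_shift(x)           # (b, h, t, t) scores over key positions
naive = torch.stack(
    [x[:, :, i, (t - 1) - i : (2 * t - 1) - i] for i in range(t)], dim=2
)
assert torch.allclose(shifted, naive)
```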
- - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) - # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.norm(x) - x = x.permute(0, 2, 1) - - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def _test_text_encoder(): - vocabs = 500 - d_model = 192 - batch_size = 5 - seq_len = 100 - - m = TextEncoder(vocabs=vocabs, d_model=d_model) - x, m, logs, mask = m( - x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)), - x_lengths=torch.full((batch_size,), seq_len), - ) - print(x.shape, m.shape, logs.shape, mask.shape) - - -if __name__ == "__main__": - _test_text_encoder() diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py deleted file mode 100644 index 3c9046add..000000000 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict, List - -import tacotron_cleaner.cleaners - -try: - from piper_phonemize import phonemize_espeak -except Exception as ex: - raise RuntimeError( - f"{ex}\nPlease run\n" - "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" - ) - -from utils import intersperse - - -class Tokenizer(object): - def __init__(self, tokens: str): - """ - Args: - tokens: the file that maps tokens to ids - """ - # Parse token file - self.token2id: Dict[str, int] = {} - with open(tokens, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - id = int(info[0]) - else: - token, id = info[0], int(info[1]) - assert token not in self.token2id, token - self.token2id[token] = id - - # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md - self.pad_id = self.token2id["_"] # padding - self.sos_id = self.token2id["^"] # beginning of an utterance (bos) - self.eos_id = self.token2id["$"] # end of an utterance (eos) - self.space_id = self.token2id[" "] # word separator (whitespace) - - self.vocab_size = len(self.token2id) - - def texts_to_token_ids( - self, - texts: List[str], - intersperse_blank: bool = True, - add_sos: bool = False, - add_eos: bool = False, - lang: str = "en-us", - ) -> List[List[int]]: - """ - Args: - texts: - A list of transcripts. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - add_sos: - Whether to add sos token at the start. 
- add_eos: - Whether to add eos token at the end. - lang: - Language argument passed to phonemize_espeak(). - - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for text in texts: - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens_list = phonemize_espeak(text, lang) - tokens = [] - for t in tokens_list: - tokens.extend(t) - - token_ids = [] - for t in tokens: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - - if intersperse_blank: - token_ids = intersperse(token_ids, self.pad_id) - if add_sos: - token_ids = [self.sos_id] + token_ids - if add_eos: - token_ids = token_ids + [self.eos_id] - - token_ids_list.append(token_ids) - - return token_ids_list - - def tokens_to_token_ids( - self, - tokens_list: List[str], - intersperse_blank: bool = True, - add_sos: bool = False, - add_eos: bool = False, - ) -> List[List[int]]: - """ - Args: - tokens_list: - A list of token list, each corresponding to one utterance. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - add_sos: - Whether to add sos token at the start. - add_eos: - Whether to add eos token at the end. - - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for tokens in tokens_list: - token_ids = [] - for t in tokens: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - - if intersperse_blank: - token_ids = intersperse(token_ids, self.pad_id) - if add_sos: - token_ids = [self.sos_id] + token_ids - if add_eos: - token_ids = token_ids + [self.eos_id] - - token_ids_list.append(token_ids) - - return token_ids_list diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py deleted file mode 100755 index 184ae79af..000000000 --- a/egs/ljspeech/TTS/vits/train.py +++ /dev/null @@ -1,929 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
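A usage sketch of the `Tokenizer` above, mirroring how `test_onnx.py` calls it; the special symbols follow the piper convention noted in `__init__` (`_` pad, `^` sos, `$` eos):

```python
from tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")
token_ids = tokenizer.texts_to_token_ids(
    ["Ask not what your country can do for you."],
    intersperse_blank=True,  # interleave the pad id with the phoneme ids
    add_sos=True,
    add_eos=True,
)
print(len(token_ids), len(token_ids[0]))  # 1 utterance and its token-id count
print(tokenizer.pad_id, tokenizer.sos_id, tokenizer.eos_id, tokenizer.vocab_size)
```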
- - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--lr", type=float, default=2.0e-4, help="The base learning rate." - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=20, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. - It controls the model size. low -> runs faster. 
- """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - """ - params = AttributeDict( - { - # training params - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 200, - "env_info": get_env_info(), - "sampling_rate": 22050, - "frame_shift": 256, - "frame_length": 1024, - "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "n_mels": 80, - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_mel": 45.0, # loss scaling coefficient for Mel loss - "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss - "lambda_dur": 1.0, # loss scaling coefficient for duration loss - "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = { - "n_mels": params.n_mels, - "frame_length": params.frame_length, - "frame_shift": params.frame_shift, - } - model = VITS( - vocab_size=params.vocab_size, - feature_dim=params.feature_dim, - sampling_rate=params.sampling_rate, - model_type=params.model_type, - mel_loss_params=mel_loss_params, - lambda_adv=params.lambda_adv, - lambda_mel=params.lambda_mel, - lambda_feat_match=params.lambda_feat_match, - lambda_dur=params.lambda_dur, - lambda_kl=params.lambda_kl, - ) - return model - - -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): - """Parse batch data""" - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - return audio, audio_lens, features, features_lens, tokens, tokens_lens - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to phonemes. - optimizer_g: - The optimizer for generator. - optimizer_d: - The optimizer for discriminator. - scheduler_g: - The learning rate scheduler for generator, we call step() every epoch. - scheduler_d: - The learning rate scheduler for discriminator, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
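`prepare_input()` above turns a Python list of token-id lists into a padded tensor plus per-utterance lengths via k2's ragged tensors. In isolation the conversion looks like this (a sketch, k2 required):

```python
import k2

tokens = k2.RaggedTensor([[5, 9, 3], [7, 2]])
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]         # lengths: [3, 2]
padded = tokens.pad(mode="constant", padding_value=0)  # (B, T_max), 0 used as pad id here
print(tokens_lens, padded)
```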
- """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(loss_d).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(loss_g).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/speech_", - speech_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used 
to summary the stats over iterations - tot_loss = MetricsTracker() - returned_sample = None - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # infer for first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference( - text=tokens[0, : tokens_lens[0].item()] - ) - audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = ( - (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - ) - assert audio_len_pred == len(audio_pred), ( - audio_len_pred, - len(audio_pred), - ) - audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() - returned_sample = (audio_pred, audio_gt) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - tokenizer: Tokenizer, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. 
We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - generator = model.generator - discriminator = model.discriminator - - num_param_g = sum([p.numel() for p in generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer_g = torch.optim.AdamW( - generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - optimizer_d = torch.optim.AdamW( - discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - ljspeech = LJSpeechTtsDataModule(args) - - train_cuts = ljspeech.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # You should use 
../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = ljspeech.train_dataloaders(train_cuts) - - valid_cuts = ljspeech.valid_cuts() - valid_dl = ljspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - # step per epoch - scheduler_g.step() - scheduler_d.step() - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py deleted file mode 100644 index c20d13130..000000000 --- a/egs/ljspeech/TTS/vits/transform.py +++ /dev/null @@ -1,218 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py - -"""Flow-related transformation. - -This code is derived from https://github.com/bayesiains/nflows. 
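# A sketch of how the training script above is typically launched (not taken from the
# file itself; the flag spellings are assumptions inferred from the argparse destinations
# used in run()/main(), e.g. args.world_size, args.exp_dir, args.tokens, args.max_duration):
#
#   ./vits/train.py \
#     --world-size 4 \
#     --num-epochs 1000 \
#     --start-epoch 1 \
#     --exp-dir vits/exp \
#     --tokens data/tokens.txt \
#     --max-duration 500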
- -""" - -import numpy as np -import torch -from torch.nn import functional as F - -DEFAULT_MIN_BIN_WIDTH = 1e-3 -DEFAULT_MIN_BIN_HEIGHT = 1e-3 -DEFAULT_MIN_DERIVATIVE = 1e-3 - - -# TODO(kan-bayashi): Documentation and type hint -def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if tails is None: - spline_fn = rational_quadratic_spline - spline_kwargs = {} - else: - spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = {"tails": tails, "tail_bound": tail_bound} - - outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs - ) - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) - - if tails == "linear": - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) - constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 - else: - raise RuntimeError("{} tails are not implemented.".format(tails)) - - ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, - right=tail_bound, - bottom=-tail_bound, - top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - ) - - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0.0, - right=1.0, - bottom=0.0, - top=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError("Input to a transform is not within its domain") - - num_bins = unnormalized_widths.shape[-1] - - if min_bin_width * num_bins > 1.0: - raise ValueError("Minimal bin width too large for the number of bins") - if min_bin_height * num_bins > 1.0: - raise ValueError("Minimal bin height too large for the number of bins") - - widths = F.softmax(unnormalized_widths, dim=-1) - widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left - cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] - - derivatives = min_derivative + F.softplus(unnormalized_derivatives) - - heights = F.softmax(unnormalized_heights, dim=-1) - heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom - cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] - - if inverse: - bin_idx = _searchsorted(cumheights, inputs)[..., None] - else: - bin_idx = _searchsorted(cumwidths, inputs)[..., None] - - input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] - input_bin_widths = widths.gather(-1, bin_idx)[..., 0] - - input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] - delta = heights / widths - input_delta = delta.gather(-1, bin_idx)[..., 0] - - input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] - input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] - - input_heights = heights.gather(-1, bin_idx)[..., 0] - - if inverse: - a = (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) + input_heights * (input_delta - input_derivatives) - b = input_heights * input_derivatives - (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) - c = -input_delta * (inputs - input_cumheights) - - discriminant = b.pow(2) - 4 * a * c - assert (discriminant >= 0).all() - - root = (2 * c) / (-b - torch.sqrt(discriminant)) - outputs = root * input_bin_widths + input_cumwidths - - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, -logabsdet - else: - theta = (inputs - input_cumwidths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - - numerator = input_heights * ( - input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta - ) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - outputs = input_cumheights + numerator / denominator - - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, logabsdet - - -def _searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps - return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py deleted file mode 100644 index e1a9c7b3c..000000000 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei 
Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LJSpeechTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
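# Each dataloader worker is then re-seeded with `seed + worker_id` by the
# _SeedWorkers helper defined near the top of this file, so that workers do not
# draw identical random sampling/augmentation streams.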
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" - ) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py deleted file mode 100644 index 6a067f596..000000000 --- a/egs/ljspeech/TTS/vits/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
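# A minimal usage sketch for the data module above (not part of the original recipe;
# it assumes the LJSpeech cut manifests with precomputed spectrograms already exist
# under data/spectrogram):
import argparse

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/spectrogram"])

ljspeech = LJSpeechTtsDataModule(args)
train_dl = ljspeech.train_dataloaders(ljspeech.train_cuts())
valid_dl = ljspeech.valid_dataloaders(ljspeech.valid_cuts())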
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -from lhotse.dataset.sampling.base import CutSampler -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py -def get_random_segments( - x: torch.Tensor, - x_lengths: torch.Tensor, - segment_size: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Get random segments. - - Args: - x (Tensor): Input tensor (B, C, T). - x_lengths (Tensor): Length tensor (B,). - segment_size (int): Segment size. - - Returns: - Tensor: Segmented tensor (B, C, segment_size). - Tensor: Start index tensor (B,). - - """ - b, c, t = x.size() - max_start_idx = x_lengths - segment_size - max_start_idx[max_start_idx < 0] = 0 - start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( - dtype=torch.long, - ) - segments = get_segments(x, start_idxs, segment_size) - - return segments, start_idxs - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py -def get_segments( - x: torch.Tensor, - start_idxs: torch.Tensor, - segment_size: int, -) -> torch.Tensor: - """Get segments. - - Args: - x (Tensor): Input tensor (B, C, T). - start_idxs (Tensor): Start index tensor (B,). - segment_size (int): Segment size. - - Returns: - Tensor: Segmented tensor (B, C, segment_size). - - """ - b, c, t = x.size() - segments = x.new_zeros(b, c, segment_size) - for i, start_idx in enumerate(start_idxs): - segments[i] = x[i, :, start_idx : start_idx + segment_size] - return segments - - -# from https://github.com/jaywalnut310/vit://github.com/jaywalnut310/vits/blob/main/commons.py -def intersperse(sequence, item=0): - result = [item] * (len(sequence) * 2 + 1) - result[1::2] = sequence - return result - - -# from https://github.com/jaywalnut310/vits/blob/main/utils.py -MATPLOTLIB_FLAG = False - - -def plot_feature(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data - - -class MetricsTracker(collections.defaultdict): - def __init__(self): - # Passing the type 'int' to the base-class constructor - # makes undefined items default to int() which is zero. - # This class will play a role as metrics tracker. 
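# A small sketch (not from the original file) of the segment and intersperse helpers
# defined above, using made-up tensors:
import torch

x = torch.arange(10.0).view(1, 1, 10)  # (B=1, C=1, T=10)
segs = get_segments(x, start_idxs=torch.tensor([3]), segment_size=4)
# segs -> tensor([[[3., 4., 5., 6.]]]): a fixed-size slice per batch item

print(intersperse([3, 5, 7]))  # -> [0, 3, 0, 5, 0, 7, 0]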
- # It can record many metrics, including but not limited to loss. - super(MetricsTracker, self).__init__(int) - - def __add__(self, other: "MetricsTracker") -> "MetricsTracker": - ans = MetricsTracker() - for k, v in self.items(): - ans[k] = v - for k, v in other.items(): - ans[k] = ans[k] + v - return ans - - def __mul__(self, alpha: float) -> "MetricsTracker": - ans = MetricsTracker() - for k, v in self.items(): - ans[k] = v * alpha - return ans - - def __str__(self) -> str: - ans = "" - for k, v in self.norm_items(): - norm_value = "%.4g" % v - ans += str(k) + "=" + str(norm_value) + ", " - samples = "%.2f" % self["samples"] - ans += "over " + str(samples) + " samples." - return ans - - def norm_items(self) -> List[Tuple[str, float]]: - """ - Returns a list of pairs, like: - [('loss_1', 0.1), ('loss_2', 0.07)] - """ - samples = self["samples"] if "samples" in self else 1 - ans = [] - for k, v in self.items(): - if k == "samples": - continue - norm_value = float(v) / samples - ans.append((k, norm_value)) - return ans - - def reduce(self, device): - """ - Reduce using torch.distributed, which I believe ensures that - all processes get the total. - """ - keys = sorted(self.keys()) - s = torch.tensor([float(self[k]) for k in keys], device=device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - for k, v in zip(keys, s.cpu().tolist()): - self[k] = v - - def write_summary( - self, - tb_writer: SummaryWriter, - prefix: str, - batch_idx: int, - ) -> None: - """Add logging information to a TensorBoard writer. - - Args: - tb_writer: a TensorBoard writer - prefix: a prefix for the name of the loss, e.g. "train/valid_", - or "train/current_" - batch_idx: The current batch index, used as the x-axis of the plot. - """ - for k, v in self.norm_items(): - tb_writer.add_scalar(prefix + k, v, batch_idx) - - -# checkpoint saving and loading -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def save_checkpoint( - filename: Path, - model: Union[nn.Module, DDP], - params: Optional[Dict[str, Any]] = None, - optimizer_g: Optional[Optimizer] = None, - optimizer_d: Optional[Optimizer] = None, - scheduler_g: Optional[LRSchedulerType] = None, - scheduler_d: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - rank: int = 0, -) -> None: - """Save training information to a file. - - Args: - filename: - The checkpoint filename. - model: - The model to be saved. We only save its `state_dict()`. - model_avg: - The stored model averaged from the start of training. - params: - User defined parameters, e.g., epoch, loss. - optimizer_g: - The optimizer for generator used in the training. - Its `state_dict` will be saved. - optimizer_d: - The optimizer for discriminator used in the training. - Its `state_dict` will be saved. - scheduler_g: - The learning rate scheduler for generator used in the training. - Its `state_dict` will be saved. - scheduler_d: - The learning rate scheduler for discriminator used in the training. - Its `state_dict` will be saved. - scalar: - The GradScaler to be saved. We only save its `state_dict()`. - rank: - Used in DDP. We save checkpoint only for the node whose rank is 0. - Returns: - Return None. 
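# A small sketch (not from the original file) of how MetricsTracker above is used in
# train.py: per-batch stats are stored scaled by the batch size, summed across batches,
# then normalized per sample.
info = MetricsTracker()
info["samples"] = 2
info["generator_loss"] = 1.2 * 2

tot = MetricsTracker() + info  # __add__ sums values key by key
print(tot.norm_items())        # -> [('generator_loss', 1.2)]
print(tot)                     # -> generator_loss=1.2, over 2.00 samples.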
- """ - if rank != 0: - return - - logging.info(f"Saving checkpoint to {filename}") - - if isinstance(model, DDP): - model = model.module - - checkpoint = { - "model": model.state_dict(), - "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, - "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, - "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, - "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, - "grad_scaler": scaler.state_dict() if scaler is not None else None, - "sampler": sampler.state_dict() if sampler is not None else None, - } - - if params: - for k, v in params.items(): - assert k not in checkpoint - checkpoint[k] = v - - torch.save(checkpoint, filename) diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py deleted file mode 100644 index a1fabf9ad..000000000 --- a/egs/ljspeech/TTS/vits/vits.py +++ /dev/null @@ -1,664 +0,0 @@ -# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""VITS module for GAN-TTS task.""" - -import copy -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -from generator import VITSGenerator -from hifigan import ( - HiFiGANMultiPeriodDiscriminator, - HiFiGANMultiScaleDiscriminator, - HiFiGANMultiScaleMultiPeriodDiscriminator, - HiFiGANPeriodDiscriminator, - HiFiGANScaleDiscriminator, -) -from loss import ( - DiscriminatorAdversarialLoss, - FeatureMatchLoss, - GeneratorAdversarialLoss, - KLDivergenceLoss, - MelSpectrogramLoss, -) -from torch.cuda.amp import autocast -from utils import get_segments - -AVAILABLE_GENERATERS = { - "vits_generator": VITSGenerator, -} -AVAILABLE_DISCRIMINATORS = { - "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, - "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, - "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, - "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, - "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA -} - -LOW_CONFIG = { - "hidden_channels": 96, - "decoder_upsample_scales": (8, 8, 4), - "decoder_channels": 256, - "decoder_upsample_kernel_sizes": (16, 16, 8), - "decoder_resblock_kernel_sizes": (3, 5, 7), - "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), - "text_encoder_cnn_module_kernel": 3, -} - -MEDIUM_CONFIG = { - "hidden_channels": 192, - "decoder_upsample_scales": (8, 8, 4), - "decoder_channels": 256, - "decoder_upsample_kernel_sizes": (16, 16, 8), - "decoder_resblock_kernel_sizes": (3, 5, 7), - "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), - "text_encoder_cnn_module_kernel": 3, -} - -HIGH_CONFIG = { - "hidden_channels": 192, - "decoder_upsample_scales": (8, 8, 2, 2), - "decoder_channels": 512, - "decoder_upsample_kernel_sizes": (16, 16, 4, 4), - "decoder_resblock_kernel_sizes": (3, 7, 11), - "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)), - "text_encoder_cnn_module_kernel": 5, -} - - -class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" - - def __init__( - self, - # generator related - vocab_size: int, - feature_dim: int = 513, - sampling_rate: int = 22050, - generator_type: str = "vits_generator", - model_type: str = "", - generator_params: Dict[str, Any] = { - "hidden_channels": 192, - 
"spks": None, - "langs": None, - "spk_embed_dim": None, - "global_channels": -1, - "segment_size": 32, - "text_encoder_attention_heads": 2, - "text_encoder_ffn_expand": 4, - "text_encoder_cnn_module_kernel": 5, - "text_encoder_blocks": 6, - "text_encoder_dropout_rate": 0.1, - "decoder_kernel_size": 7, - "decoder_channels": 512, - "decoder_upsample_scales": [8, 8, 2, 2], - "decoder_upsample_kernel_sizes": [16, 16, 4, 4], - "decoder_resblock_kernel_sizes": [3, 7, 11], - "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "use_weight_norm_in_decoder": True, - "posterior_encoder_kernel_size": 5, - "posterior_encoder_layers": 16, - "posterior_encoder_stacks": 1, - "posterior_encoder_base_dilation": 1, - "posterior_encoder_dropout_rate": 0.0, - "use_weight_norm_in_posterior_encoder": True, - "flow_flows": 4, - "flow_kernel_size": 5, - "flow_base_dilation": 1, - "flow_layers": 4, - "flow_dropout_rate": 0.0, - "use_weight_norm_in_flow": True, - "use_only_mean_in_flow": True, - "stochastic_duration_predictor_kernel_size": 3, - "stochastic_duration_predictor_dropout_rate": 0.5, - "stochastic_duration_predictor_flows": 4, - "stochastic_duration_predictor_dds_conv_layers": 3, - }, - # discriminator related - discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", - discriminator_params: Dict[str, Any] = { - "scales": 1, - "scale_downsample_pooling": "AvgPool1d", - "scale_downsample_pooling_params": { - "kernel_size": 4, - "stride": 2, - "padding": 2, - }, - "scale_discriminator_params": { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [15, 41, 5, 3], - "channels": 128, - "max_downsample_channels": 1024, - "max_groups": 16, - "bias": True, - "downsample_scales": [2, 2, 4, 4, 1], - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - "follow_official_norm": False, - "periods": [2, 3, 5, 7, 11], - "period_discriminator_params": { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [5, 3], - "channels": 32, - "downsample_scales": [3, 3, 3, 3, 1], - "max_downsample_channels": 1024, - "bias": True, - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - }, - # loss related - generator_adv_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "loss_type": "mse", - }, - discriminator_adv_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "loss_type": "mse", - }, - feat_match_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "average_by_layers": False, - "include_final_outputs": True, - }, - mel_loss_params: Dict[str, Any] = { - "frame_shift": 256, - "frame_length": 1024, - "n_mels": 80, - }, - lambda_adv: float = 1.0, - lambda_mel: float = 45.0, - lambda_feat_match: float = 2.0, - lambda_dur: float = 1.0, - lambda_kl: float = 1.0, - cache_generator_outputs: bool = True, - ): - """Initialize VITS module. - - Args: - idim (int): Input vocabulary size. - odim (int): Acoustic feature dimension. The actual output channels will - be 1 since VITS is the end-to-end text-to-wave model but for the - compatibility odim is used to indicate the acoustic feature dimension. - sampling_rate (int): Sampling rate, not used for the training but it will - be referred in saving waveform during the inference. 
- model_type (str): If not empty, must be one of: low, medium, high - generator_type (str): Generator type. - generator_params (Dict[str, Any]): Parameter dict for generator. - discriminator_type (str): Discriminator type. - discriminator_params (Dict[str, Any]): Parameter dict for discriminator. - generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator - adversarial loss. - discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for - discriminator adversarial loss. - feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. - mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. - lambda_adv (float): Loss scaling coefficient for adversarial loss. - lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. - lambda_feat_match (float): Loss scaling coefficient for feat match loss. - lambda_dur (float): Loss scaling coefficient for duration loss. - lambda_kl (float): Loss scaling coefficient for KL divergence loss. - cache_generator_outputs (bool): Whether to cache generator outputs. - - """ - super().__init__() - - generator_params = copy.deepcopy(generator_params) - discriminator_params = copy.deepcopy(discriminator_params) - generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params) - discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params) - feat_match_loss_params = copy.deepcopy(feat_match_loss_params) - mel_loss_params = copy.deepcopy(mel_loss_params) - - if model_type != "": - assert model_type in ("low", "medium", "high"), model_type - if model_type == "low": - generator_params.update(LOW_CONFIG) - elif model_type == "medium": - generator_params.update(MEDIUM_CONFIG) - elif model_type == "high": - generator_params.update(HIGH_CONFIG) - else: - raise ValueError(f"Unknown model_type: ${model_type}") - - # define modules - generator_class = AVAILABLE_GENERATERS[generator_type] - if generator_type == "vits_generator": - # NOTE(kan-bayashi): Update parameters for the compatibility. - # The idim and odim is automatically decided from input data, - # where idim represents #vocabularies and odim represents - # the input acoustic feature dimension. 
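# NOTE: feature_dim defaults to 513, i.e. the number of linear-spectrogram bins
# produced by a 1024-point FFT (1024 // 2 + 1); it is forwarded to the generator
# below as aux_channels.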
- generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) - self.generator = generator_class( - **generator_params, - ) - discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] - self.discriminator = discriminator_class( - **discriminator_params, - ) - self.generator_adv_loss = GeneratorAdversarialLoss( - **generator_adv_loss_params, - ) - self.discriminator_adv_loss = DiscriminatorAdversarialLoss( - **discriminator_adv_loss_params, - ) - self.feat_match_loss = FeatureMatchLoss( - **feat_match_loss_params, - ) - mel_loss_params.update(sampling_rate=sampling_rate) - self.mel_loss = MelSpectrogramLoss( - **mel_loss_params, - ) - self.kl_loss = KLDivergenceLoss() - - # coefficients - self.lambda_adv = lambda_adv - self.lambda_mel = lambda_mel - self.lambda_kl = lambda_kl - self.lambda_feat_match = lambda_feat_match - self.lambda_dur = lambda_dur - - # cache - self.cache_generator_outputs = cache_generator_outputs - self._cache = None - - # store sampling rate for saving wav file - # (not used for the training) - self.sampling_rate = sampling_rate - - # store parameters for test compatibility - self.spks = self.generator.spks - self.langs = self.generator.langs - self.spk_embed_dim = self.generator.spk_embed_dim - - def forward( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool = False, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - forward_generator: bool = True, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform generator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). - feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - forward_generator (bool): Whether to forward generator. - - Returns: - - loss (Tensor): Loss scalar tensor. - - stats (Dict[str, float]): Statistics to be monitored. - """ - if forward_generator: - return self._forward_generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - speech=speech, - speech_lengths=speech_lengths, - return_sample=return_sample, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - return self._forward_discrminator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - speech=speech, - speech_lengths=speech_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - - def _forward_generator( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool = False, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform generator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). 
- feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - """ - # setup - feats = feats.transpose(1, 2) - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - outs = self.generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - outs = self._cache - - # store cache - if self.training and self.cache_generator_outputs and not reuse_cache: - self._cache = outs - - # parse outputs - speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs - _, z_p, m_p, logs_p, _, logs_q = outs_ - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * self.generator.upsample_factor, - ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_) - with torch.no_grad(): - # do not store discriminator gradient in generator turn - p = self.discriminator(speech_) - - # calculate losses - with autocast(enabled=False): - if not return_sample: - mel_loss = self.mel_loss(speech_hat_, speech_) - else: - mel_loss, (mel_hat_, mel_) = self.mel_loss( - speech_hat_, speech_, return_mel=True - ) - kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) - dur_loss = torch.sum(dur_nll.float()) - adv_loss = self.generator_adv_loss(p_hat) - feat_match_loss = self.feat_match_loss(p_hat, p) - - mel_loss = mel_loss * self.lambda_mel - kl_loss = kl_loss * self.lambda_kl - dur_loss = dur_loss * self.lambda_dur - adv_loss = adv_loss * self.lambda_adv - feat_match_loss = feat_match_loss * self.lambda_feat_match - loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss - - stats = dict( - generator_loss=loss.item(), - generator_mel_loss=mel_loss.item(), - generator_kl_loss=kl_loss.item(), - generator_dur_loss=dur_loss.item(), - generator_adv_loss=adv_loss.item(), - generator_feat_match_loss=feat_match_loss.item(), - ) - - if return_sample: - stats["returned_sample"] = ( - speech_hat_[0].data.cpu().numpy(), - speech_[0].data.cpu().numpy(), - mel_hat_[0].data.cpu().numpy(), - mel_[0].data.cpu().numpy(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return loss, stats - - def _forward_discrminator( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform discriminator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). - feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
- spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - """ - # setup - feats = feats.transpose(1, 2) - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - outs = self.generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - outs = self._cache - - # store cache - if self.cache_generator_outputs and not reuse_cache: - self._cache = outs - - # parse outputs - speech_hat_, _, _, start_idxs, *_ = outs - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * self.generator.upsample_factor, - ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_.detach()) - p = self.discriminator(speech_) - - # calculate losses - with autocast(enabled=False): - real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) - loss = real_loss + fake_loss - - stats = dict( - discriminator_loss=loss.item(), - discriminator_real_loss=real_loss.item(), - discriminator_fake_loss=fake_loss.item(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return loss, stats - - def inference( - self, - text: torch.Tensor, - feats: Optional[torch.Tensor] = None, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - durations: Optional[torch.Tensor] = None, - noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, - alpha: float = 1.0, - max_len: Optional[int] = None, - use_teacher_forcing: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run inference for single sample. - - Args: - text (Tensor): Input text index tensor (T_text,). - feats (Tensor): Feature tensor (T_feats, aux_channels). - sids (Tensor): Speaker index tensor (1,). - spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). - lids (Tensor): Language index tensor (1,). - durations (Tensor): Ground-truth duration tensor (T_text,). - noise_scale (float): Noise scale value for flow. - noise_scale_dur (float): Noise scale value for duration predictor. - alpha (float): Alpha parameter to control the speed of generated speech. - max_len (Optional[int]): Maximum length. - use_teacher_forcing (bool): Whether to use teacher forcing. - - Returns: - * wav (Tensor): Generated waveform tensor (T_wav,). - * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). - * duration (Tensor): Predicted duration tensor (T_text,). 
- """ - # setup - text = text[None] - text_lengths = torch.tensor( - [text.size(1)], - dtype=torch.long, - device=text.device, - ) - if sids is not None: - sids = sids.view(1) - if lids is not None: - lids = lids.view(1) - if durations is not None: - durations = durations.view(1, 1, -1) - - # inference - if use_teacher_forcing: - assert feats is not None - feats = feats[None].transpose(1, 2) - feats_lengths = torch.tensor( - [feats.size(2)], - dtype=torch.long, - device=feats.device, - ) - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - max_len=max_len, - use_teacher_forcing=use_teacher_forcing, - ) - else: - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - sids=sids, - spembs=spembs, - lids=lids, - dur=durations, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - max_len=max_len, - ) - return wav.view(-1), att_w[0], dur[0] - - def inference_batch( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - durations: Optional[torch.Tensor] = None, - noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, - alpha: float = 1.0, - max_len: Optional[int] = None, - use_teacher_forcing: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run inference for one batch. - - Args: - text (Tensor): Input text index tensor (B, T_text). - text_lengths (Tensor): Input text index tensor (B,). - sids (Tensor): Speaker index tensor (B,). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Tensor): Language index tensor (B,). - noise_scale (float): Noise scale value for flow. - noise_scale_dur (float): Noise scale value for duration predictor. - alpha (float): Alpha parameter to control the speed of generated speech. - max_len (Optional[int]): Maximum length. - - Returns: - * wav (Tensor): Generated waveform tensor (B, T_wav). - * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). - * duration (Tensor): Predicted duration tensor (B, T_text). - """ - # inference - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - sids=sids, - spembs=spembs, - lids=lids, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - max_len=max_len, - ) - return wav, att_w, dur diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py deleted file mode 100644 index 5db461d5c..000000000 --- a/egs/ljspeech/TTS/vits/wavenet.py +++ /dev/null @@ -1,348 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""WaveNet modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
- -""" - -import logging -import math -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F - - -class WaveNet(torch.nn.Module): - """WaveNet with global conditioning.""" - - def __init__( - self, - in_channels: int = 1, - out_channels: int = 1, - kernel_size: int = 3, - layers: int = 30, - stacks: int = 3, - base_dilation: int = 2, - residual_channels: int = 64, - aux_channels: int = -1, - gate_channels: int = 128, - skip_channels: int = 64, - global_channels: int = -1, - dropout_rate: float = 0.0, - bias: bool = True, - use_weight_norm: bool = True, - use_first_conv: bool = False, - use_last_conv: bool = False, - scale_residual: bool = False, - scale_skip_connect: bool = False, - ): - """Initialize WaveNet module. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - kernel_size (int): Kernel size of dilated convolution. - layers (int): Number of residual block layers. - stacks (int): Number of stacks i.e., dilation cycles. - base_dilation (int): Base dilation factor. - residual_channels (int): Number of channels in residual conv. - gate_channels (int): Number of channels in gated conv. - skip_channels (int): Number of channels in skip conv. - aux_channels (int): Number of channels for local conditioning feature. - global_channels (int): Number of channels for global conditioning feature. - dropout_rate (float): Dropout rate. 0.0 means no dropout applied. - bias (bool): Whether to use bias parameter in conv layer. - use_weight_norm (bool): Whether to use weight norm. If set to true, it will - be applied to all of the conv layers. - use_first_conv (bool): Whether to use the first conv layers. - use_last_conv (bool): Whether to use the last conv layers. - scale_residual (bool): Whether to scale the residual outputs. - scale_skip_connect (bool): Whether to scale the skip connection outputs. - - """ - super().__init__() - self.layers = layers - self.stacks = stacks - self.kernel_size = kernel_size - self.base_dilation = base_dilation - self.use_first_conv = use_first_conv - self.use_last_conv = use_last_conv - self.scale_skip_connect = scale_skip_connect - - # check the number of layers and stacks - assert layers % stacks == 0 - layers_per_stack = layers // stacks - - # define first convolution - if self.use_first_conv: - self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) - - # define residual blocks - self.conv_layers = torch.nn.ModuleList() - for layer in range(layers): - dilation = base_dilation ** (layer % layers_per_stack) - conv = ResidualBlock( - kernel_size=kernel_size, - residual_channels=residual_channels, - gate_channels=gate_channels, - skip_channels=skip_channels, - aux_channels=aux_channels, - global_channels=global_channels, - dilation=dilation, - dropout_rate=dropout_rate, - bias=bias, - scale_residual=scale_residual, - ) - self.conv_layers += [conv] - - # define output layers - if self.use_last_conv: - self.last_conv = torch.nn.Sequential( - torch.nn.ReLU(inplace=True), - Conv1d1x1(skip_channels, skip_channels, bias=True), - torch.nn.ReLU(inplace=True), - Conv1d1x1(skip_channels, out_channels, bias=True), - ) - - # apply weight norm - if use_weight_norm: - self.apply_weight_norm() - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Calculate forward propagation. 
- - Args: - x (Tensor): Input noise signal (B, 1, T) if use_first_conv else - (B, residual_channels, T). - x_mask (Optional[Tensor]): Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, out_channels, T) if use_last_conv else - (B, residual_channels, T). - - """ - # encode to hidden representation - if self.use_first_conv: - x = self.first_conv(x) - - # residual block - skips = 0.0 - for f in self.conv_layers: - x, h = f(x, x_mask=x_mask, c=c, g=g) - skips = skips + h - x = skips - if self.scale_skip_connect: - x = x * math.sqrt(1.0 / len(self.conv_layers)) - - # apply final layers - if self.use_last_conv: - x = self.last_conv(x) - - return x - - def remove_weight_norm(self): - """Remove weight normalization module from all of the layers.""" - - def _remove_weight_norm(m: torch.nn.Module): - try: - logging.debug(f"Weight norm is removed from {m}.") - torch.nn.utils.remove_weight_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_weight_norm) - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - @staticmethod - def _get_receptive_field_size( - layers: int, - stacks: int, - kernel_size: int, - base_dilation: int, - ) -> int: - assert layers % stacks == 0 - layers_per_cycle = layers // stacks - dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] - return (kernel_size - 1) * sum(dilations) + 1 - - @property - def receptive_field_size(self) -> int: - """Return receptive field size.""" - return self._get_receptive_field_size( - self.layers, self.stacks, self.kernel_size, self.base_dilation - ) - - -class Conv1d(torch.nn.Conv1d): - """Conv1d module with customized initialization.""" - - def __init__(self, *args, **kwargs): - """Initialize Conv1d module.""" - super().__init__(*args, **kwargs) - - def reset_parameters(self): - """Reset parameters.""" - torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") - if self.bias is not None: - torch.nn.init.constant_(self.bias, 0.0) - - -class Conv1d1x1(Conv1d): - """1x1 Conv1d with customized initialization.""" - - def __init__(self, in_channels: int, out_channels: int, bias: bool): - """Initialize 1x1 Conv1d module.""" - super().__init__( - in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias - ) - - -class ResidualBlock(torch.nn.Module): - """Residual block module in WaveNet.""" - - def __init__( - self, - kernel_size: int = 3, - residual_channels: int = 64, - gate_channels: int = 128, - skip_channels: int = 64, - aux_channels: int = 80, - global_channels: int = -1, - dropout_rate: float = 0.0, - dilation: int = 1, - bias: bool = True, - scale_residual: bool = False, - ): - """Initialize ResidualBlock module. - - Args: - kernel_size (int): Kernel size of dilation convolution layer. - residual_channels (int): Number of channels for residual connection. - skip_channels (int): Number of channels for skip connection. - aux_channels (int): Number of local conditioning channels. - dropout (float): Dropout probability. - dilation (int): Dilation factor. 
- bias (bool): Whether to add bias parameter in convolution layers. - scale_residual (bool): Whether to scale the residual outputs. - - """ - super().__init__() - self.dropout_rate = dropout_rate - self.residual_channels = residual_channels - self.skip_channels = skip_channels - self.scale_residual = scale_residual - - # check - assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." - assert gate_channels % 2 == 0 - - # dilation conv - padding = (kernel_size - 1) // 2 * dilation - self.conv = Conv1d( - residual_channels, - gate_channels, - kernel_size, - padding=padding, - dilation=dilation, - bias=bias, - ) - - # local conditioning - if aux_channels > 0: - self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) - else: - self.conv1x1_aux = None - - # global conditioning - if global_channels > 0: - self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) - else: - self.conv1x1_glo = None - - # conv output is split into two groups - gate_out_channels = gate_channels // 2 - - # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency - # (integrate res 1x1 + skip 1x1 convs) - self.conv1x1_out = Conv1d1x1( - gate_out_channels, residual_channels + skip_channels, bias=bias - ) - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, residual_channels, T). - x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor for residual connection (B, residual_channels, T). - Tensor: Output tensor for skip connection (B, skip_channels, T). - - """ - residual = x - x = F.dropout(x, p=self.dropout_rate, training=self.training) - x = self.conv(x) - - # split into two part for gated activation - splitdim = 1 - xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) - - # local conditioning - if c is not None: - c = self.conv1x1_aux(c) - ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ca, xb + cb - - # global conditioning - if g is not None: - g = self.conv1x1_glo(g) - ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ga, xb + gb - - x = torch.tanh(xa) * torch.sigmoid(xb) - - # residual + skip 1x1 conv - x = self.conv1x1_out(x) - if x_mask is not None: - x = x * x_mask - - # split integrated conv results - x, s = x.split([self.residual_channels, self.skip_channels], dim=1) - - # for residual connection - x = x + residual - if self.scale_residual: - x = x * math.sqrt(0.5) - - return x, s diff --git a/egs/mdcc/ASR/README.md b/egs/mdcc/ASR/README.md deleted file mode 100644 index 112845b73..000000000 --- a/egs/mdcc/ASR/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Introduction - -Multi-Domain Cantonese Corpus (MDCC), consists of 73.6 hours of clean read speech paired with -transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises philosophy, -politics, education, culture, lifestyle and family domains, covering a wide range of topics. 
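Before turning to the MDCC recipe files, a short note on the WaveNet module deleted above: the dilation pattern repeats every `layers // stacks` layers, so the receptive field follows directly from `_get_receptive_field_size`, and each `ResidualBlock` applies the standard gated-tanh activation. A sketch using the file's default hyperparameters (assumed, since callers may override them):

```
# Defaults from the deleted WaveNet: layers=30, stacks=3, kernel_size=3, base_dilation=2.
layers, stacks, kernel_size, base_dilation = 30, 3, 3, 2
layers_per_cycle = layers // stacks                    # 10
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
receptive_field = (kernel_size - 1) * sum(dilations) + 1
print(receptive_field)                                 # 6139

# Gated activation inside each ResidualBlock, with the dilated-conv output split
# into halves (x_a, x_b) and optional local/global conditioning added to each half:
#   z = tanh(x_a + c_a + g_a) * sigmoid(x_b + c_b + g_b)
# A 1x1 conv then yields the residual path (scaled by sqrt(0.5) when scale_residual=True)
# and the skip path that WaveNet.forward accumulates across layers.
```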
- -Manuscript can be found at: https://arxiv.org/abs/2201.02419 - -# Transducers - - - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | - -The decoder is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/mdcc/ASR/RESULTS.md b/egs/mdcc/ASR/RESULTS.md deleted file mode 100644 index ff7ddc957..000000000 --- a/egs/mdcc/ASR/RESULTS.md +++ /dev/null @@ -1,41 +0,0 @@ -## Results - -#### Zipformer - -See - -[./zipformer](./zipformer) - -##### normal-scaled model, number of model parameters: 74470867, i.e., 74.47 M - -| | test | valid | comment | -|------------------------|------|-------|-----------------------------------------| -| greedy search | 7.45 | 7.51 | --epoch 45 --avg 35 | -| modified beam search | 6.68 | 6.73 | --epoch 45 --avg 35 | -| fast beam search | 7.22 | 7.28 | --epoch 45 --avg 35 | - -The training command: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./zipformer/train.py \ - --world-size 4 \ - --start-epoch 1 \ - --num-epochs 50 \ - --use-fp16 1 \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 -``` - -The decoding command: - -``` - ./zipformer/decode.py \ - --epoch 45 \ - --avg 35 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search # modified_beam_search -``` - -The pretrained model is available at: https://huggingface.co/zrjin/icefall-asr-mdcc-zipformer-2024-03-11/ \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg.py b/egs/mdcc/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/mdcc/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg_using_openfst.py b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py deleted file mode 120000 index d34edd7f3..000000000 --- a/egs/mdcc/ASR/local/compile_hlg_using_openfst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg_using_openfst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_lg.py b/egs/mdcc/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/mdcc/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compute_fbank_mdcc.py b/egs/mdcc/ASR/local/compute_fbank_mdcc.py deleted file mode 100755 index 647b21127..000000000 --- a/egs/mdcc/ASR/local/compute_fbank_mdcc.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the aishell dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_mdcc( - num_mel_bins: int = 80, - perturb_speed: bool = False, - whisper_fbank: bool = False, - output_dir: str = "data/fbank", -): - src_dir = Path("data/manifests") - output_dir = Path(output_dir) - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train", - "valid", - "test", - ) - prefix = "mdcc" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and perturb_speed: - logging.info("Doing speed perturb") - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--output-dir", - type=str, - default="data/fbank", - help="Output directory. 
Default: data/fbank.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_mdcc( - num_mel_bins=args.num_mel_bins, - perturb_speed=args.perturb_speed, - whisper_fbank=args.whisper_fbank, - output_dir=args.output_dir, - ) diff --git a/egs/mdcc/ASR/local/display_manifest_statistics.py b/egs/mdcc/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 27cf8c943..000000000 --- a/egs/mdcc/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. -""" - - -from lhotse import load_manifest_lazy - - -def main(): - path = "./data/fbank/mdcc_cuts_train.jsonl.gz" - path = "./data/fbank/mdcc_cuts_valid.jsonl.gz" - path = "./data/fbank/mdcc_cuts_test.jsonl.gz" - - cuts = load_manifest_lazy(path) - cuts.describe(full=True) - - -if __name__ == "__main__": - main() - -""" -data/fbank/mdcc_cuts_train.jsonl.gz (with speed perturbation) -_________________________________________ -_ Cuts count: _ 195360 -_________________________________________ -_ Total duration (hh:mm:ss) _ 173:44:59 -_________________________________________ -_ mean _ 3.2 -_________________________________________ -_ std _ 2.1 -_________________________________________ -_ min _ 0.2 -_________________________________________ -_ 25% _ 1.8 -_________________________________________ -_ 50% _ 2.7 -_________________________________________ -_ 75% _ 4.0 -_________________________________________ -_ 99% _ 11.0 _ -_________________________________________ -_ 99.5% _ 12.4 _ -_________________________________________ -_ 99.9% _ 14.8 _ -_________________________________________ -_ max _ 16.7 _ -_________________________________________ -_ Recordings available: _ 195360 _ -_________________________________________ -_ Features available: _ 195360 _ -_________________________________________ -_ Supervisions available: _ 195360 _ -_________________________________________ - -data/fbank/mdcc_cuts_valid.jsonl.gz -________________________________________ -_ Cuts count: _ 5663 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 05:03:12 _ -________________________________________ -_ mean _ 3.2 _ -________________________________________ -_ std _ 2.0 _ -________________________________________ -_ min _ 0.3 _ -________________________________________ -_ 25% _ 1.8 _ -________________________________________ -_ 50% _ 2.7 _ 
-________________________________________ -_ 75% _ 4.0 _ -________________________________________ -_ 99% _ 10.9 _ -________________________________________ -_ 99.5% _ 12.3 _ -________________________________________ -_ 99.9% _ 14.4 _ -________________________________________ -_ max _ 14.8 _ -________________________________________ -_ Recordings available: _ 5663 _ -________________________________________ -_ Features available: _ 5663 _ -________________________________________ -_ Supervisions available: _ 5663 _ -________________________________________ - -data/fbank/mdcc_cuts_test.jsonl.gz -________________________________________ -_ Cuts count: _ 12492 _ -________________________________________ -_ Total duration (hh:mm:ss) _ 11:00:31 _ -________________________________________ -_ mean _ 3.2 _ -________________________________________ -_ std _ 2.0 _ -________________________________________ -_ min _ 0.2 _ -________________________________________ -_ 25% _ 1.8 _ -________________________________________ -_ 50% _ 2.7 _ -________________________________________ -_ 75% _ 4.0 _ -________________________________________ -_ 99% _ 10.5 _ -________________________________________ -_ 99.5% _ 12.1 _ -________________________________________ -_ 99.9% _ 14.0 _ -________________________________________ -_ max _ 14.8 _ -________________________________________ -_ Recordings available: _ 12492 _ -________________________________________ -_ Features available: _ 12492 _ -________________________________________ -_ Supervisions available: _ 12492 _ -________________________________________ - -""" diff --git a/egs/mdcc/ASR/local/prepare_char.py b/egs/mdcc/ASR/local/prepare_char.py deleted file mode 120000 index 42743b544..000000000 --- a/egs/mdcc/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_char_lm_training_data.py b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py deleted file mode 120000 index 2374cafdd..000000000 --- a/egs/mdcc/ASR/local/prepare_char_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char_lm_training_data.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang.py b/egs/mdcc/ASR/local/prepare_lang.py deleted file mode 120000 index bee8d5f03..000000000 --- a/egs/mdcc/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang_fst.py b/egs/mdcc/ASR/local/prepare_lang_fst.py deleted file mode 120000 index c5787c534..000000000 --- a/egs/mdcc/ASR/local/prepare_lang_fst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/preprocess_mdcc.py b/egs/mdcc/ASR/local/preprocess_mdcc.py deleted file mode 100755 index cd0dc7de8..000000000 --- a/egs/mdcc/ASR/local/preprocess_mdcc.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
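The statistics above are the practical output of display_manifest_statistics.py: they guide the choice of minimum/maximum cut duration for training. A hedged sketch of applying such limits with lhotse (the 1.0 s and 15.0 s bounds are illustrative readings of the tables, not values taken from the recipe):

```
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/mdcc_cuts_train.jsonl.gz")

# Drop overly short/long utterances before training; compare the
# remove_short_and_long_utt() function referenced in the docstring above.
cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 15.0)
```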
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes a text file "data/lang_char/text" as input, the file consist of -lines each containing a transcript, applies text norm and generates the following -files in the directory "data/lang_char": - - text_norm - - words.txt - - words_no_ids.txt - - text_words_segmentation -""" - -import argparse -import logging -from pathlib import Path -from typing import List - -import pycantonese -from tqdm.auto import tqdm - -from icefall.utils import is_cjk - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare char lexicon", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - "-i", - default="data/lang_char/text", - type=str, - help="The input text file", - ) - parser.add_argument( - "--output-dir", - "-o", - default="data/lang_char", - type=str, - help="The output directory", - ) - return parser - - -def get_norm_lines(lines: List[str]) -> List[str]: - def _text_norm(text: str) -> str: - # to cope with the protocol for transcription: - # When taking notes, the annotators adhere to the following guidelines: - # 1) If the audio contains pure music, the annotators mark the label - # "(music)" in the file name of its transcript. 2) If the utterance - # contains one or several sentences with background music or noise, the - # annotators mark the label "(music)" before each sentence in the transcript. - # 3) The annotators use {} symbols to enclose words they are uncertain - # about, for example, {梁佳佳},我是{}人. 
- - # here we manually fix some errors in the transcript - - return ( - text.strip() - .replace("(music)", "") - .replace("(music", "") - .replace("{", "") - .replace("}", "") - .replace("BB所以就指腹為親喇", "BB 所以就指腹為親喇") - .upper() - ) - - return [_text_norm(line) for line in lines] - - -def get_word_segments(lines: List[str]) -> List[str]: - # the current pycantonese segmenter does not handle the case when the input - # is code switching, so we need to handle it separately - - new_lines = [] - - for line in tqdm(lines, desc="Segmenting lines"): - try: - # code switching - if len(line.strip().split(" ")) > 1: - segments = [] - for segment in line.strip().split(" "): - if segment.strip() == "": - continue - try: - if not is_cjk(segment[0]): # en segment - segments.append(segment) - else: # zh segment - segments.extend(pycantonese.segment(segment)) - except Exception as e: - logging.error(f"Failed to process segment: {segment}") - raise e - new_lines.append(" ".join(segments) + "\n") - # not code switching - else: - new_lines.append(" ".join(pycantonese.segment(line)) + "\n") - except Exception as e: - logging.error(f"Failed to process line: {line}") - raise e - return new_lines - - -def get_words(lines: List[str]) -> List[str]: - words = set() - for line in tqdm(lines, desc="Getting words"): - words.update(line.strip().split(" ")) - return list(words) - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - - input_file = Path(args.input_file) - output_dir = Path(args.output_dir) - - assert output_dir.is_dir(), f"{output_dir} does not exist" - assert input_file.is_file(), f"{input_file} does not exist" - - lines = input_file.read_text(encoding="utf-8").strip().split("\n") - - norm_lines = get_norm_lines(lines) - with open(output_dir / "text_norm", "w+", encoding="utf-8") as f: - f.writelines([line + "\n" for line in norm_lines]) - - text_words_segments = get_word_segments(norm_lines) - with open(output_dir / "text_words_segmentation", "w+", encoding="utf-8") as f: - f.writelines(text_words_segments) - - words = get_words(text_words_segments)[1:] # remove "\n" from words - with open(output_dir / "words_no_ids.txt", "w+", encoding="utf-8") as f: - f.writelines([word + "\n" for word in sorted(words)]) - - words = ( - ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"] - + sorted(words) - + ["#0", "<s>", "</s>"] - ) - - with open(output_dir / "words.txt", "w+", encoding="utf-8") as f: - f.writelines([f"{word} {i}\n" for i, word in enumerate(words)]) diff --git a/egs/mdcc/ASR/local/text2segments.py b/egs/mdcc/ASR/local/text2segments.py deleted file mode 100755 index 8ce7ab7e5..000000000 --- a/egs/mdcc/ASR/local/text2segments.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# 2022 Xiaomi Corp. (authors: Weiji Zhuang) -# 2024 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
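To make the flow of preprocess_mdcc.py above concrete, a hypothetical walk-through of its helpers (the sample line is invented, the helpers are assumed to be importable from the script, and the exact pycantonese segmentation is deliberately not asserted):

```
# pip install pycantonese
lines = ["(music) 我哋聽日去{旺角}"]      # made-up transcript line with the protocol markers

norm_lines = get_norm_lines(lines)        # removes "(music)" and the {...} markers, uppercases
segmented = get_word_segments(norm_lines) # Cantonese word segmentation via pycantonese;
                                          # non-CJK (English) segments pass through unchanged
words = get_words(segmented)              # unique word list, later written to words_no_ids.txt
print(norm_lines, segmented, words)
```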
- - -""" -This script takes as input "text", which refers to the transcript file for -MDCC: - - text -and generates the output file text_word_segmentation which is implemented -with word segmenting: - - text_words_segmentation -""" - -import argparse -from typing import List - -import pycantonese -from tqdm.auto import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Cantonese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - "-i", - default="data/lang_char/text", - type=str, - help="the input text file for MDCC", - ) - parser.add_argument( - "--output-file", - "-o", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with words segmenting for MDCC", - ) - - return parser - - -def get_word_segments(lines: List[str]) -> List[str]: - return [ - " ".join(pycantonese.segment(line)) + "\n" - for line in tqdm(lines, desc="Segmenting lines") - ] - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - with open(input_file, "r", encoding="utf-8") as fr: - lines = fr.readlines() - - new_lines = get_word_segments(lines) - - with open(output_file, "w", encoding="utf-8") as fw: - fw.writelines(new_lines) - - -if __name__ == "__main__": - main() diff --git a/egs/mdcc/ASR/local/text2token.py b/egs/mdcc/ASR/local/text2token.py deleted file mode 120000 index 81e459d69..000000000 --- a/egs/mdcc/ASR/local/text2token.py +++ /dev/null @@ -1 +0,0 @@ -../../../aidatatang_200zh/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/mdcc/ASR/prepare.sh b/egs/mdcc/ASR/prepare.sh deleted file mode 100755 index f4d9bc47e..000000000 --- a/egs/mdcc/ASR/prepare.sh +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 -perturb_speed=true - - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/mdcc -# |-- README.md -# |-- audio/ -# |-- clip_info_rthk.csv -# |-- cnt_asr_metadata_full.csv -# |-- cnt_asr_test_metadata.csv -# |-- cnt_asr_train_metadata.csv -# |-- cnt_asr_valid_metadata.csv -# |-- data_statistic.py -# |-- length -# |-- podcast_447_2021.csv -# |-- test.txt -# |-- transcription/ -# `-- words_length -# You can download them from: -# https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view?usp=drive_link -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: Download data" - - # If you have pre-downloaded it to /path/to/mdcc, - # you can create a symlink - # - # ln -sfv /path/to/mdcc $dl_dir/mdcc - # - # The directory structure is - # mdcc/ - # |-- README.md - # |-- audio/ - # |-- clip_info_rthk.csv - # |-- cnt_asr_metadata_full.csv - # |-- cnt_asr_test_metadata.csv - # |-- cnt_asr_train_metadata.csv - # |-- cnt_asr_valid_metadata.csv - # |-- data_statistic.py - # |-- length - # |-- podcast_447_2021.csv - # |-- test.txt - # |-- transcription/ - # `-- words_length - - if [ ! -d $dl_dir/mdcc/audio ]; then - lhotse download mdcc $dl_dir - - # this will download and unzip dataset.zip to $dl_dir/ - - mv $dl_dir/dataset $dl_dir/mdcc - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare MDCC manifest" - # We assume that you have downloaded the MDCC corpus - # to $dl_dir/mdcc - if [ ! -f data/manifests/.mdcc_manifests.done ]; then - log "Might take 40 minutes to traverse the directory." - mkdir -p data/manifests - lhotse prepare mdcc $dl_dir/mdcc data/manifests - touch data/manifests/.mdcc_manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for MDCC" - if [ ! -f data/fbank/.mdcc.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_mdcc.py --perturb-speed ${perturb_speed} - touch data/fbank/.mdcc.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -lang_char_dir=data/lang_char -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - mkdir -p $lang_char_dir - - # Prepare text. - # Note: in Linux, you can install jq with the following command: - # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - # 2. chmod +x ./jq - # 3. cp jq /usr/bin - if [ ! -f $lang_char_dir/text ]; then - gunzip -c data/manifests/mdcc_supervisions_train.jsonl.gz \ - |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ - > $lang_char_dir/train_text - - cat $lang_char_dir/train_text > $lang_char_dir/text - - gunzip -c data/manifests/mdcc_supervisions_test.jsonl.gz \ - |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ - > $lang_char_dir/valid_text - - cat $lang_char_dir/valid_text >> $lang_char_dir/text - - gunzip -c data/manifests/mdcc_supervisions_valid.jsonl.gz \ - |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ - > $lang_char_dir/test_text - - cat $lang_char_dir/test_text >> $lang_char_dir/text - fi - - if [ ! 
-f $lang_char_dir/text_words_segmentation ]; then - ./local/preprocess_mdcc.py --input-file $lang_char_dir/text \ - --output-dir $lang_char_dir - - mv $lang_char_dir/text $lang_char_dir/_text - cp $lang_char_dir/text_words_segmentation $lang_char_dir/text - fi - - if [ ! -f $lang_char_dir/tokens.txt ]; then - ./local/prepare_char.py --lang-dir $lang_char_dir - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare G" - - mkdir -p data/lm - - # Train LM on transcripts - if [ ! -f data/lm/3-gram.unpruned.arpa ]; then - python3 ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lang_char_dir/text_words_segmentation \ - -lm data/lm/3-gram.unpruned.arpa - fi - - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="$lang_char_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt - fi - - if [ ! -f $lang_char_dir/HLG.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_char_dir \ - --ngram-G ./data/lm/G_3_gram_char.fst.txt - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compile LG & HLG" - - ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char - ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Generate LM training data" - - log "Processing char based data" - out_dir=data/lm_training_char - mkdir -p $out_dir $dl_dir/lm - - if [ ! -f $dl_dir/lm/mdcc-train-word.txt ]; then - ./local/text2segments.py --input-file $lang_char_dir/train_text \ - --output-file $dl_dir/lm/mdcc-train-word.txt - fi - - # training words - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/mdcc-train-word.txt \ - --lm-archive $out_dir/lm_data.pt - - # valid words - if [ ! -f $dl_dir/lm/mdcc-valid-word.txt ]; then - ./local/text2segments.py --input-file $lang_char_dir/valid_text \ - --output-file $dl_dir/lm/mdcc-valid-word.txt - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/mdcc-valid-word.txt \ - --lm-archive $out_dir/lm_data_valid.pt - - # test words - if [ ! -f $dl_dir/lm/mdcc-test-word.txt ]; then - ./local/text2segments.py --input-file $lang_char_dir/test_text \ - --output-file $dl_dir/lm/mdcc-test-word.txt - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $dl_dir/lm/mdcc-test-word.txt \ - --lm-archive $out_dir/lm_data_test.pt -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Sort LM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of tokens - # in a sentence. 
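Stage 9's comment above already states the idea; conceptually, the symlinked sort_lm_training_data.py does something like the following (a sketch, not the actual script):

```
# Token-id sentences as produced by prepare_char_lm_training_data.py (illustrative values).
sentences = [[23, 5, 9], [7], [11, 2, 8, 4]]

# Sort by sentence length (number of tokens) in descending order so that
# similarly sized batches can be formed cheaply during RNN-LM training.
sentences.sort(key=len, reverse=True)
print([len(s) for s in sentences])   # [4, 3, 1]
```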
- - out_dir=data/lm_training_char - mkdir -p $out_dir - ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data_valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data_test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Train RNN LM model" - python ../../../icefall/rnn_lm/train.py \ - --start-epoch 0 \ - --world-size 1 \ - --num-epochs 20 \ - --use-fp16 0 \ - --embedding-dim 512 \ - --hidden-dim 512 \ - --num-layers 2 \ - --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data $out_dir/sorted_lm_data.pt \ - --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ - --vocab-size 4336 \ - --master-port 12345 -fi diff --git a/egs/mdcc/ASR/shared b/egs/mdcc/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/mdcc/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/__init__.py b/egs/mdcc/ASR/zipformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/mdcc/ASR/zipformer/asr_datamodule.py b/egs/mdcc/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index 1f49b6520..000000000 --- a/egs/mdcc/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,382 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# Copyright 2024 Xiaomi Corporation (Author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class MdccAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
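A minimal sketch of how a datamodule like MdccAsrDataModule is wired into a training or decoding script (the argparse setup is illustrative, and the dataloaders naturally require the prepared manifests under data/fbank):

```
import argparse

parser = argparse.ArgumentParser()
MdccAsrDataModule.add_arguments(parser)        # adds --manifest-dir, --max-duration, ...
args = parser.parse_args(["--max-duration", "200"])

data_module = MdccAsrDataModule(args)
train_dl = data_module.train_dataloaders(data_module.train_cuts())
valid_dl = data_module.valid_dataloaders(data_module.valid_cuts())
test_dl = data_module.test_dataloaders(data_module.test_cuts())
```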
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "mdcc_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get valid cuts") - return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_valid.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_test.jsonl.gz") diff --git a/egs/mdcc/ASR/zipformer/beam_search.py b/egs/mdcc/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/mdcc/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decode.py 
b/egs/mdcc/ASR/zipformer/decode.py deleted file mode 100755 index ce104baf7..000000000 --- a/egs/mdcc/ASR/zipformer/decode.py +++ /dev/null @@ -1,813 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Mingshuang Luo, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import MdccAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
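As a concrete reading of the --epoch/--avg/--use-averaged-model help above (and of the --epoch 45 --avg 35 setting quoted in RESULTS.md), a small sketch of which checkpoints get combined; the epoch-N.pt naming is assumed from the help text:

```
epoch, avg = 45, 35

# Without --use-averaged-model: average the last `avg` epoch checkpoints ending at `epoch`,
# i.e. epoch-11.pt through epoch-45.pt here.
averaged_epochs = list(range(epoch - avg + 1, epoch + 1))
print(averaged_epochs[0], averaged_epochs[-1])   # 11 45

# With --use-averaged-model (the default): the model is averaged over the range
# (epoch - avg, epoch], i.e. epochs 10 (excluded) to 45, and only the checkpoints
# for epoch 10 and epoch 45 are loaded, as the help text describes.
```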
- Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
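One aside before the padding logic continues below: the --blank-penalty option defined earlier is easiest to see on a toy tensor (illustrative values; this simply restates the note in its help string):

```
import torch

logits = torch.tensor([[2.0, 1.5, 0.3]])   # [batch, vocab] for one frame, blank id = 0
blank_penalty = 0.5

# Subtract the penalty from the blank logit, making the decoder less eager to emit
# blank at every frame (which tends to reduce deletions).
logits[:, 0] -= blank_penalty
print(logits)   # tensor([[1.5000, 1.5000, 0.3000]])
```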
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - 
key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MdccAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif 
len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - mdcc = MdccAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." 
- ) - return T > 0 - - valid_cuts = mdcc.valid_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - valid_dl = mdcc.valid_dataloaders(valid_cuts) - - test_cuts = mdcc.test_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = mdcc.test_dataloaders(test_cuts) - - test_sets = ["valid", "test"] - test_dls = [valid_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/mdcc/ASR/zipformer/decode_stream.py b/egs/mdcc/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/mdcc/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decoder.py b/egs/mdcc/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/mdcc/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/encoder_interface.py b/egs/mdcc/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/mdcc/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py deleted file mode 120000 index f9d756352..000000000 --- a/egs/mdcc/ASR/zipformer/export-onnx-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py deleted file mode 120000 index 652346001..000000000 --- a/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/mdcc/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx.py b/egs/mdcc/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/mdcc/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export.py b/egs/mdcc/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/mdcc/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/joiner.py b/egs/mdcc/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/mdcc/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/model.py b/egs/mdcc/ASR/zipformer/model.py deleted 
file mode 120000 index cd7e07d72..000000000 --- a/egs/mdcc/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_check.py b/egs/mdcc/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/mdcc/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_decode.py b/egs/mdcc/ASR/zipformer/onnx_decode.py deleted file mode 100755 index 1ed4a9fa1..000000000 --- a/egs/mdcc/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. -""" - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import MdccAsrDataModule -from lhotse.cut import Cut -from onnx_pretrained import OnnxModel, greedy_search - -from icefall.utils import setup_logger, store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: k2.SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - Mapping ids to tokens. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. 
- """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - hyps = [[token_table[h] for h in hyp] for hyp in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: k2.SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - Mapping ids to tokens. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MdccAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." 
- res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(args.tokens) - assert token_table[0] == "" - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - - mdcc = MdccAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - valid_cuts = mdcc.valid_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - valid_dl = mdcc.valid_dataloaders(valid_cuts) - - test_cuts = mdcc.test_net_cuts() - test_cuts = test_cuts.filter(remove_short_utt) - test_dl = mdcc.test_dataloaders(test_cuts) - - test_sets = ["valid", "test"] - test_dl = [valid_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/mdcc/ASR/zipformer/optim.py b/egs/mdcc/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/mdcc/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling.py b/egs/mdcc/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/mdcc/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling_converter.py b/egs/mdcc/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/mdcc/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_beam_search.py b/egs/mdcc/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/mdcc/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_decode.py b/egs/mdcc/ASR/zipformer/streaming_decode.py deleted file mode 100755 index dadb0b55f..000000000 --- a/egs/mdcc/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,881 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# 
you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -from asr_datamodule import MdccAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="Path to the lang dir(containing lexicon, tokens, etc.)", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. 
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - encoder_out=encoder_out, - streams=decode_streams, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - lexicon: - The Lexicon. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
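# (Sketch of the flow that follows, paraphrasing the surrounding code: each cut gets its
# own DecodeStream holding fresh initial encoder states, its fbank features and its
# reference text; once num_decode_streams streams are active, decode_one_chunk() is
# called repeatedly and finished streams are drained into decode_results.)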
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - list(decode_streams[i].ground_truth.strip()), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - key = f"greedy_search_{key}" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_{key}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}_{key}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MdccAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - mdcc = MdccAsrDataModule(args) - - valid_cuts = mdcc.valid_cuts() - test_cuts = mdcc.test_cuts() - - test_sets = ["valid", "test"] - test_cuts = [valid_cuts, test_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/mdcc/ASR/zipformer/subsampling.py b/egs/mdcc/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/mdcc/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py deleted file mode 100755 index 730db7718..000000000 --- a/egs/mdcc/ASR/zipformer/train.py +++ /dev/null @@ -1,1346 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 50 \ - --start-epoch 1 \ - --exp-dir zipformer/exp \ - --max-duration 350 - -# For mix precision training: - -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 50 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import MdccAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="""Maximum left-contexts for causal training, measured in frames which will - be converted to a number of chunks. 
If splitting into chunks, - chunk left-context frames will be chosen randomly from this list; else not relevant.""", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). 
We will scale the simple loss - with this parameter before adding to the final loss.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. 
- scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
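-            # Concretely, as implemented below: while the scale is below 8.0
-            # it is doubled every 100 batches, and while it is between 8.0 and
-            # 32.0 it is doubled every 400 batches. A scale below 0.01 triggers
-            # a warning (and a one-time "bad model" checkpoint), and a scale
-            # below 1.0e-05 aborts training, since the scale only collapses
-            # that far when gradients repeatedly overflow.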
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - mdcc = MdccAsrDataModule(args) - - train_cuts = mdcc.train_cuts() - valid_cuts = mdcc.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds - # - # Caution: There is a reason to select 15.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = mdcc.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) - - valid_dl = mdcc.valid_dataloaders(valid_cuts) - - if False and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - MdccAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/mdcc/ASR/zipformer/zipformer.py b/egs/mdcc/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/mdcc/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/README.md b/egs/mgb2/ASR/README.md deleted file mode 100644 index 2bc4b000b..000000000 --- a/egs/mgb2/ASR/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# MGB2 - -The Multi-Dialect Broadcast News Arabic Speech Recognition (MGB-2): -The second edition of the Multi-Genre Broadcast (MGB-2) Challenge is -an evaluation of speech recognition and lightly supervised alignment -using TV recordings in Arabic. The speech data is broad and multi-genre, -spanning the whole range of TV output, and represents a challenging task for -speech technology. In 2016, the challenge featured two new Arabic tracks based -on TV data from Aljazeera. It was an official challenge at the 2016 IEEE -Workshop on Spoken Language Technology. The 1,200 hours MGB-2: from Aljazeera -TV programs have been manually captioned with no timing information. -QCRI Arabic ASR system has been used to recognize all programs. The ASR output -was used to align the manual captioning and produce speech segments for -training speech recognition. More than 20 hours from 2015 programs have been -transcribed verbatim and manually segmented. This data is split into a -development set of 10 hours, and a similar evaluation set of 10 hours. -Both the development and evaluation data have been released in the 2016 MGB -challenge - -Official reference: - -Ali, Ahmed, et al. "The MGB-2 challenge: Arabic multi-dialect broadcast media recognition." -2016 IEEE Spoken Language Technology Workshop (SLT). IEEE, 2016. 
- -IEEE link: https://ieeexplore.ieee.org/abstract/document/7846277 - -## Stateless Pruned Transducer Performance Record (after 30 epochs) - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 15.52 | 15.28 | --epoch 18, --avg 5, --max-duration 200 | -| modified beam search | 13.88 | 13.7 | --epoch 18, --avg 5, --max-duration 200 | -| fast beam search | 14.62 | 14.36 | --epoch 18, --avg 5, --max-duration 200 | - -## Conformer-CTC Performance Record (after 40 epochs) - -| Decoding method | dev WER | test WER | -|---------------------------|------------|---------| -| attention-decoder | 15.62 | 15.01 | -| whole-lattice-rescoring | 15.89 | 15.08 | - - -See [RESULTS](/egs/mgb2/ASR/RESULTS.md) for details. diff --git a/egs/mgb2/ASR/RESULTS.md b/egs/mgb2/ASR/RESULTS.md deleted file mode 100644 index 2a7ea7664..000000000 --- a/egs/mgb2/ASR/RESULTS.md +++ /dev/null @@ -1,236 +0,0 @@ -# Results - - -### MGB2 all data BPE training results (Stateless Pruned Transducer) - -#### 2022-09-07 - -The WERs are - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 15.52 | 15.28 | --epoch 18, --avg 5, --max-duration 200 | -| modified beam search | 13.88 | 13.7 | --epoch 18, --avg 5, --max-duration 200 | -| fast beam search | 14.62 | 14.36 | --epoch 18, --avg 5, --max-duration 200| - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - - - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --max-duration 300 \ - --num-buckets 50 -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars - -The decoding command is: -``` -epoch=18 -avg=5 -for method in greedy_search modified_beam_search fast_beam_search; do - ./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --beam-size 10 \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --decoding-method $method \ - --max-sym-per-frame 1 \ - --num-encoder-layers 12 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --use-averaged-model True -done -``` - -### MGB2 all data BPE training results (Conformer-CTC) (after 40 epochs) - -#### 2022-06-04 - -You can find a pretrained model, training logs, decoding logs, and decoding results at: -https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06 - -The best WER, as of 2022-06-04, for the MGB2 test dataset is below - -Using whole lattice HLG decoding + n-gram LM rescoring - -| | dev | test | -|-----|------------|------------| -| WER | 15.62 | 15.01 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.1 | - | - - -Using n-best (n=0.5) attention decoder rescoring - -| | dev | test | -|-----|------------|------------| -| WER | 15.89 | 15.08 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.01 | 0.5 | - - -To reproduce the above result, use the following commands for training: - -# Note: the model was trained on V-100 32GB GPU - -``` -cd egs/mgb2/ASR -. 
./path.sh -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1" -./conformer_ctc/train.py \ - --lang-dir data/lang_bpe_5000 \ - --att-rate 0.8 \ - --lr-factor 10 \ - --max-duration \ - --concatenate-cuts 0 \ - --world-size 2 \ - --bucketing-sampler 1 \ - --max-duration 100 \ - --start-epoch 0 \ - --num-epochs 40 - -``` - -and the following command for nbest decoding - -``` -./conformer_ctc/decode.py \ - --lang-dir data/lang_bpe_5000 \ - --max-duration 30 \ - --concatenate-cuts 0 \ - --bucketing-sampler 1 \ - --num-paths 1000 \ - --epoch 40 \ - --avg 5 \ - --method attention-decoder \ - --nbest-scale 0.5 -``` - -and the following command for whole-lattice decoding - -``` -./conformer_ctc/decode.py \ - --epoch 40 \ - --avg 5 \ - --exp-dir conformer_ctc/exp_5000_att0.8 \ - --lang-dir data/lang_bpe_5000 \ - --max-duration 30 \ - --concatenate-cuts 0 \ - --bucketing-sampler 1 \ - --num-paths 1000 \ - --method whole-lattice-rescoring -``` - - -The tensorboard log for training is available at -https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars - - -### MGB2 100h BPE training results (Conformer-CTC) (after 33 epochs) - -#### 2022-06-04 - -The best WER, as of 2022-06-04, for the MGB2 test dataset is below - -Using whole lattice HLG decoding + n-gram LM rescoring - -| | dev | test | -|-----|------------|------------| -| WER | 25.32 | 23.53 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.1 | - | - - -Using n-best (n=0.5) HLG decoding + n-gram LM rescoring + attention decoder rescoring: - -| | dev | test | -|-----|------------|------------| -| WER | 27.87 | 26.12 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.01 | 0.3 | - - -To reproduce the above result, use the following commands for training: - -# Note: the model was trained on V-100 32GB GPU - -``` -cd egs/mgb2/ASR -. 
./path.sh -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1" -./conformer_ctc/train.py \ - --lang-dir data/lang_bpe_5000 \ - --att-rate 0.8 \ - --lr-factor 10 \ - --max-duration \ - --concatenate-cuts 0 \ - --world-size 2 \ - --bucketing-sampler 1 \ - --max-duration 100 \ - --start-epoch 0 \ - --num-epochs 40 - -``` - -and the following command for nbest decoding - -``` -./conformer_ctc/decode.py \ - --lang-dir data/lang_bpe_5000 \ - --max-duration 30 \ - --concatenate-cuts 0 \ - --bucketing-sampler 1 \ - --num-paths 1000 \ - --epoch 40 \ - --avg 5 \ - --method attention-decoder \ - --nbest-scale 0.5 -``` - -and the following command for whole-lattice decoding - -``` -./conformer_ctc/decode.py \ - --lang-dir data/lang_bpe_5000 \ - --max-duration 30 \ - --concatenate-cuts 0 \ - --bucketing-sampler 1 \ - --num-paths 1000 \ - --epoch 40 \ - --avg 5 \ - --method whole-lattice-rescoring -``` - -The tensorboard log for training is available at - - - - - diff --git a/egs/mgb2/ASR/conformer_ctc/__init__.py b/egs/mgb2/ASR/conformer_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py deleted file mode 100644 index 48921d71f..000000000 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2022 Johns Hopkins University (Amir Hussein) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class MGB2AsrDataModule: - - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=1, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") - - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
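-            # For example, with --enable-musan True and --concatenate-cuts
-            # True, the resulting list is [CutConcatenate, CutMix]: utterances
-            # are concatenated first and MUSAN noise is mixed in afterwards.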
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
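-        # Each dataloader worker then reseeds itself with seed + worker_id
-        # (see _SeedWorkers above), so the workers apply different, but
-        # reproducible, augmentation randomness.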
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - - return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - - return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") diff --git a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/conformer.py b/egs/mgb2/ASR/conformer_ctc/conformer.py deleted file mode 120000 index d1f4209d7..000000000 --- a/egs/mgb2/ASR/conformer_ctc/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 
+0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py deleted file mode 100755 index f771d7f1e..000000000 --- a/egs/mgb2/ASR/conformer_ctc/decode.py +++ /dev/null @@ -1,695 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import pdb -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from conformer import Conformer - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=50, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - - (5) attention-decoder. Extract n paths from the LM rescored - lattice, the path with the highest score is the decoding result. - - (6) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. 
- """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=20, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
Note: If it decodes to nothing, then return None. - """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. 
- best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.method == "attention-decoder": - # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - # TODO: pass `lattice` instead of `rescored_lattice` to - # `rescore_with_attention_decoder` - - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - else: - assert False, f"Unsupported decoding method: {params.method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. 
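# The n-best and whole-lattice rescoring modes sweep a list of LM scales and,
# for each scale, keep the hypothesis maximising am_score + lm_scale * lm_score
# (see lm_scale_list above). A minimal sketch of that re-ranking idea with
# made-up hypotheses and scores; it is an illustration only, not the k2
# lattice-based implementation used in this file.

def rescore_nbest(hyps, lm_scales):
    """hyps: list of (text, am_score, lm_score). Returns {lm_scale: best text}."""
    return {
        scale: max(hyps, key=lambda h: h[1] + scale * h[2])[0]
        for scale in lm_scales
    }

if __name__ == "__main__":
    toy_nbest = [
        ("the cat sat", -10.0, -1.5),  # best LM score
        ("the cats at", -9.5, -4.0),   # best acoustic score, poor LM score
        ("a cat sat on", -11.0, -3.0),
    ]
    # A small scale keeps the acoustically best path; larger scales let the
    # language model override it.
    print(rescore_nbest(toy_nbest, lm_scales=[0.1, 0.7, 2.0]))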
- Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - # pdb.set_trace() - texts = batch["supervisions"]["text"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - if hyps_dict is not None: - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - else: - assert len(results) > 0, "It should not decode to empty in the first batch!" - this_batch = [] - hyp_words = [] - for ref_text in texts: - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - for lm_scale in results.keys(): - results[lm_scale].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. 
- G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - MGB2 = MGB2AsrDataModule(args) - - test_cuts = MGB2.test_cuts() - dev_cuts = MGB2.dev_cuts() - - test_dl = MGB2.test_dataloaders(test_cuts) - dev_dl = MGB2.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/conformer_ctc/download_lm.py b/egs/mgb2/ASR/conformer_ctc/download_lm.py deleted file mode 120000 index c9668bd2d..000000000 --- a/egs/mgb2/ASR/conformer_ctc/download_lm.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/download_lm.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/export.py b/egs/mgb2/ASR/conformer_ctc/export.py deleted file mode 120000 index 60e314d9d..000000000 --- a/egs/mgb2/ASR/conformer_ctc/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/export.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py deleted file mode 120000 index c0aea1403..000000000 --- a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/subsampling.py 
b/egs/mgb2/ASR/conformer_ctc/subsampling.py deleted file mode 120000 index 16354dc73..000000000 --- a/egs/mgb2/ASR/conformer_ctc/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py deleted file mode 120000 index 04b959ecf..000000000 --- a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py deleted file mode 120000 index 98c3be3e6..000000000 --- a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/test_subsampling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_transformer.py b/egs/mgb2/ASR/conformer_ctc/test_transformer.py deleted file mode 120000 index 8b0990ec6..000000000 --- a/egs/mgb2/ASR/conformer_ctc/test_transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/train.py b/egs/mgb2/ASR/conformer_ctc/train.py deleted file mode 100755 index 08ffee210..000000000 --- a/egs/mgb2/ASR/conformer_ctc/train.py +++ /dev/null @@ -1,766 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (Amir Hussein) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from conformer import Conformer -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=50, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. 
- If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.8, - help="""The attention rate. - The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--num-decoder-layers", - type=int, - default=6, - help="""Number of decoder layer of transformer decoder. - Setting this to 0 will not create the decoder at all (pure CTC model) - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Normalization for the input features, can be a - boolean indicating whether to do batch - normalization, or a float which means just scaling - the input features with this float value. - If given a float value, we will remove batchnorm - layer in `ConvolutionModule` as well. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - warm_step: The warm_step for Noam optimizer. 
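# --lr-factor and the warm_step value documented above parameterise the Noam
# learning-rate schedule used by this recipe's optimizer. A short sketch of
# that schedule, assuming the standard formulation from "Attention Is All You
# Need" (the defaults shown match this file: model_size=512, factor=5.0,
# warm_step=80000); it is illustrative, not the optimizer implementation itself.

def noam_lr(step, model_size=512, factor=5.0, warm_step=80000):
    """Learning rate at a given optimizer step under the Noam schedule."""
    step = max(step, 1)  # guard against 0 ** -0.5 on the very first call
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

# The rate ramps up roughly linearly until warm_step and then decays ~ step**-0.5:
# for s in (1000, 40000, 80000, 320000):
#     print(s, round(noam_lr(s), 6))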
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - "num_decoder_layers": 6, - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for Noam - "weight_decay": 1e-6, - "warm_step": 80000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. 
- is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction="none", - use_double_scores=params.use_double_scores, - ) - # filter inf from ctc_loss - ctc_loss = torch.sum( - torch.where( - ctc_loss != float("inf"), - ctc_loss, - torch.tensor(0, dtype=torch.float32).to(device), - ) - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - 
train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - # if tot_loss is None: - # logging.warning("Batch mismatch. Skipping ...") - # del batch - # del tot_loss - # continue; - # elif tot_loss.isinf() or tot_loss.isnan(): - # logging.warning("NaN or Inf loss. Skipping ...") - # del batch - # del tot_loss - # continue; - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - else: - logging.warning( - f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." - ) - continue - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - MGB2 = MGB2AsrDataModule(args) - - train_cuts = MGB2.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 0.5 <= c.duration <= 30.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = MGB2.train_dataloaders(train_cuts) - - valid_cuts = MGB2.dev_cuts() - valid_dl = MGB2.test_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/conformer_ctc/transformer.py b/egs/mgb2/ASR/conformer_ctc/transformer.py deleted file mode 120000 index 1c3f43fcf..000000000 --- a/egs/mgb2/ASR/conformer_ctc/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/__init__.py b/egs/mgb2/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/mgb2/ASR/local/compile_hlg.py b/egs/mgb2/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/mgb2/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/compute_fbank_mgb2.py b/egs/mgb2/ASR/local/compute_fbank_mgb2.py deleted file mode 100755 index 6cae69e41..000000000 --- a/egs/mgb2/ASR/local/compute_fbank_mgb2.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the MGB2 dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_mgb2(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "train", - "test", - "dev", - ) - manifests = read_manifests_if_cached( - prefix="mgb2", dataset_parts=dataset_parts, output_dir=src_dir - ) - assert manifests is not None - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - if (output_dir / f"cuts_{partition}.json.gz").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - logging.info("About to split cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_mgb2() diff --git a/egs/mgb2/ASR/local/compute_fbank_musan.py b/egs/mgb2/ASR/local/compute_fbank_musan.py deleted file mode 100755 index 5d0d69a13..000000000 --- a/egs/mgb2/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - ChunkedLilcomHdf5Writer, - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - combine, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - prefix=prefix, - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix=suffix, - ) - assert manifests is not None - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - ) - - musan_cuts_path = output_dir / "cuts_musan.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/feats_musan", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() diff --git a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 100755 index a8d5117c9..000000000 --- a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -""" -Convert a transcript file containing words to a corpus file containing tokens -for LM training with the help of a lexicon. - -If the lexicon contains phones, the resulting LM will be a phone LM; If the -lexicon contains word pieces, the resulting LM will be a word piece LM. - -If a word has multiple pronunciations, the one that appears first in the lexicon -is kept; others are removed. - -If the input transcript is: - - hello zoo world hello - world zoo - foo zoo world hellO - -and if the lexicon is - - SPN - hello h e l l o 2 - hello h e l l o - world w o r l d - zoo z o o - -Then the output is - - h e l l o 2 z o o w o r l d h e l l o 2 - w o r l d z o o - SPN z o o w o r l d SPN -""" - -import argparse -from pathlib import Path -from typing import Dict, List - -from generate_unique_lexicon import filter_multiple_pronunications - -from icefall.lexicon import read_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--transcript", - type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", - ) - parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="", help="The OOV word.") - - return parser.parse_args() - - -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: - """ - Args: - lexicon: - A dict containing pronunciations. Its keys are words and values - are pronunciations (i.e., tokens). 
- line: - A line of transcript consisting of space(s) separated words. - oov_token: - The pronunciation of the oov word if a word in `line` is not present - in the lexicon. - Returns: - Return None. - """ - s = "" - words = line.strip().split() - for i, w in enumerate(words): - tokens = lexicon.get(w, oov_token) - s += " ".join(tokens) - s += " " - print(s.strip()) - - -def main(): - args = get_args() - assert Path(args.lexicon).is_file() - assert Path(args.transcript).is_file() - assert len(args.oov) > 0 - - # Only the first pronunciation of a word is kept - lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) - - lexicon = dict(lexicon) - - assert args.oov in lexicon - - oov_token = lexicon[args.oov] - - with open(args.transcript) as f: - for line in f: - process_line(lexicon=lexicon, line=line, oov_token=oov_token) - - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/local/display_manifest_statistics.py b/egs/mgb2/ASR/local/display_manifest_statistics.py deleted file mode 100755 index d3e224905..000000000 --- a/egs/mgb2/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. 
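# A small sketch of how the duration statistics reported below can be turned
# into cut-off values for training (plain Python with synthetic durations; in
# the recipe the chosen bounds are applied by remove_short_and_long_utt() in
# the train.py scripts):

def percentile(sorted_vals, q):
    """Nearest-rank percentile of an ascending list, q in [0, 100]."""
    idx = int(round(q / 100.0 * (len(sorted_vals) - 1)))
    return sorted_vals[min(idx, len(sorted_vals) - 1)]

durations = sorted([0.3, 1.2, 3.4, 5.0, 6.1, 7.9, 9.7, 12.5, 30.2, 95.0])
low, high = 0.5, percentile(durations, 90)  # drop near-empty cuts and long outliers
kept = [d for d in durations if low <= d <= high]
print(f"kept {len(kept)}/{len(durations)} cuts with bounds ({low}s, {high}s)")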
-""" - - -from lhotse import load_manifest - - -def main(): - # path = "./data/fbank/cuts_train.jsonl.gz" - path = "./data/fbank/cuts_dev.jsonl.gz" - # path = "./data/fbank/cuts_test.jsonl.gz" - - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -# train - -Cuts count: 1125309 -Total duration (hours): 3403.9 -Speech duration (hours): 3403.9 (100.0%) -*** -Duration statistics (seconds): -mean 10.9 -std 10.1 -min 0.2 -25% 5.2 -50% 7.8 -75% 12.7 -99% 52.0 -99.5% 65.1 -99.9% 99.5 -max 228.9 - - -# test -Cuts count: 5365 -Total duration (hours): 9.6 -Speech duration (hours): 9.6 (100.0%) -*** -Duration statistics (seconds): -mean 6.4 -std 1.5 -min 1.6 -25% 5.3 -50% 6.5 -75% 7.6 -99% 9.5 -99.5% 9.7 -99.9% 10.3 -max 12.4 - -# dev -Cuts count: 5002 -Total duration (hours): 8.5 -Speech duration (hours): 8.5 (100.0%) -*** -Duration statistics (seconds): -mean 6.1 -std 1.7 -min 1.5 -25% 4.8 -50% 6.2 -75% 7.4 -99% 9.5 -99.5% 9.7 -99.9% 10.1 -max 20.3 - -""" diff --git a/egs/mgb2/ASR/local/generate_unique_lexicon.py b/egs/mgb2/ASR/local/generate_unique_lexicon.py deleted file mode 120000 index c0aea1403..000000000 --- a/egs/mgb2/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh deleted file mode 100755 index 3b673db6f..000000000 --- a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2022 QCRI (author: Amir Hussein) -# Apache 2.0 -# This script prepares the graphemic lexicon. - -dir=data/local/dict -lexicon_url1="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_grapheme_lexicon_20160209.bz2"; -lexicon_url2="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_phoneme_lexicon_20140317.bz2"; -stage=0 -lang_dir=download/lm -mkdir -p $lang_dir - -if [ $stage -le 0 ]; then - echo "$0: Downloading text for lexicon... $(date)." - wget --no-check-certificate -P $lang_dir $lexicon_url1 - wget --no-check-certificate -P $lang_dir $lexicon_url2 - bzcat $lang_dir/ar-ar_grapheme_lexicon_20160209.bz2 | sed '1,3d' | awk '{print $1}' > $lang_dir/grapheme_lexicon - bzcat $lang_dir/ar-ar_phoneme_lexicon_20140317.bz2 | sed '1,3d' | awk '{print $1}' >> $lang_dir/phoneme_lexicon - cat download/lm/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> $lang_dir/uniq_words -fi - - -if [ $stage -le 0 ]; then - echo "$0: processing lexicon text and creating lexicon... $(date)." 
- # remove vowels and rare alef wasla - cat $lang_dir/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/grapheme_lexicon.txt -fi - -echo "$0: Lexicon preparation succeeded" diff --git a/egs/mgb2/ASR/local/prepare_lang.py b/egs/mgb2/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/mgb2/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prepare_lang_bpe.py b/egs/mgb2/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/mgb2/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py deleted file mode 100755 index 99e1fa34d..000000000 --- a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2022 Amir Hussein -# Apache 2.0 - -# This script prepares givel a column of words lexicon. - -import argparse - - -def get_args(): - parser = argparse.ArgumentParser( - description="""Creates the list of characters and words in lexicon""" - ) - parser.add_argument("input", type=str, help="""Input list of words file""") - parser.add_argument("output", type=str, help="""output graphemic lexicon""") - args = parser.parse_args() - return args - - -def main(): - lex = {} - args = get_args() - with open(args.input, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - characters = list(line) - characters = " ".join(["V" if char == "*" else char for char in characters]) - lex[line] = characters - - with open(args.output, "w", encoding="utf-8") as fp: - for key in sorted(lex): - fp.write(key + " " + lex[key] + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/local/test_prepare_lang.py b/egs/mgb2/ASR/local/test_prepare_lang.py deleted file mode 120000 index f0f864998..000000000 --- a/egs/mgb2/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/test_prepare_lang.py \ No newline at end of file diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh deleted file mode 100755 index 4ea427371..000000000 --- a/egs/mgb2/ASR/prepare.sh +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2022 Johns Hopkins University (Amir Hussein) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -set -eou pipefail -nj=30 -stage=7 -stop_stage=1000 - -# We assume dl_dir (download dir) contains the following -# directories and files. -# -# - $dl_dir/mgb2 -# -# You can download the data from -# -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -# -# Note: MGB2 is not available for direct -# download, however you can fill out the form and -# download it from https://arabicspeech.org/mgb2 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - 5000 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/MGB2, - # you can create a symlink - # - # ln -sfv /path/to/mgb2 $dl_dir/MGB2 - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare mgb2 manifest" - # We assume that you have downloaded the mgb2 corpus - # to $dl_dir/mgb2 - mkdir -p data/manifests - - lhotse prepare mgb2 $dl_dir/mgb2 data/manifests - -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for mgb2" - mkdir -p data/fbank - ./local/compute_fbank_mgb2.py - # shufling the data - gunzip -c data/fbank/cuts_train.jsonl.gz | shuf | gzip -c > data/fbank/cuts_train_shuf.jsonl.gz -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - ./local/compute_fbank_musan.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" - if [[ ! -e download/lm/train/text ]]; then - # export train text file to build grapheme lexicon - lhotse kaldi export \ - data/manifests/mgb2_recordings_train.jsonl.gz \ - data/manifests/mgb2_supervisions_train.jsonl.gz \ - download/lm/train - fi - - lang_dir=data/lang_phone - mkdir -p $lang_dir - ./local/prep_mgb2_lexicon.sh - python local/prepare_mgb2_lexicon.py $dl_dir/lm/grapheme_lexicon.txt $dl_dir/lm/lexicon.txt - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi -fi - - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - files=$( - find "$dl_dir/lm/train" -name "text" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- | sed -r '/^\s*$/d' - done > $lang_dir/transcript_words.txt - fi - - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram P" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! 
-f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lang_dir/transcript_words.txt \ - -lm $lang_dir/G.arpa - - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $lang_dir/G.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - ./shared/make_kn_lm.py \ - -ngram-order 4 \ - -text $lang_dir/transcript_words.txt \ - -lm $lang_dir/4-gram.arpa - - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $lang_dir/4-gram.arpa > data/lm/G_4_gram.fst.txt - fi - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - done -fi diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py b/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 120000 index a73848de9..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index 02d01b343..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 120000 index c7c1a4b6e..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index 72338bade..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,619 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding 
multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 18 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 18 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --decoding-method beam_search \ - --beam-size 10 - -(3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 18 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 10 - -(4) fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 18 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --decoding-method fast_beam_search \ - --beam-size 10 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in 
local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - MGB2 = MGB2AsrDataModule(args) - - test_cuts = MGB2.test_cuts() - dev_cuts = MGB2.dev_cuts() - - test_dl = MGB2.test_dataloaders(test_cuts) - dev_dl = MGB2.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py 
b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 6775ee67e..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index 972e44ca4..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 7a5d7f680..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,272 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - assert args.jit is False, "Support torchscript will be added later" - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - 
filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index f5279e151..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index 7b417fd89..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/model.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 210374f22..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/optim.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100755 index 81a16f0ff..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,345 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
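export.py above produces pretrained.pt by averaging several training checkpoints through icefall's helpers, whose internals are not part of this diff. Purely as an illustration of the idea, and assuming each epoch-*.pt stores its weights under a "model" key as the deleted code suggests, a parameter-wise average could be sketched like this; the epoch range in the usage comment is hypothetical.

    import torch

    def average_state_dicts(filenames):
        # Element-wise average of the "model" state_dicts stored in several checkpoints.
        # Integer buffers are simply cast to float here; a real implementation would
        # treat them more carefully.
        avg = None
        for name in filenames:
            state = torch.load(name, map_location="cpu")["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in state.items()}
            else:
                for k in avg:
                    avg[k] += state[k].float()
        for k in avg:
            avg[k] /= len(filenames)
        return avg

    # Example (hypothetical paths and range):
    #   files = [f"pruned_transducer_stateless5/exp/epoch-{i}.pt" for i in range(16, 21)]
    #   torch.save({"model": average_state_dicts(files)},
    #              "pruned_transducer_stateless5/exp/pretrained.pt")

The icefall helpers go beyond this simple mean (for example, average_checkpoints_with_averaged_model works from the running averaged model mentioned in the --use-averaged-model help text), so treat the snippet as a conceptual outline only.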
-""" -Usage: - -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, 
- encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index ff7bfeda9..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py deleted file mode 120000 index b71d7bb81..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/test_model.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index 48468cfbd..000000000 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1162 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
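Before train.py's usage notes below, it may help to see the feature-preparation pattern from pretrained.py above in isolation: per-utterance fbank features are padded with log(1e-10) so the batch lines up before the encoder call. The sketch reuses the torchaudio and padding calls that appear in the deleted file, while the fbank object is left abstract (for example a kaldifeat.Fbank configured as above), so take it as a rough outline rather than the recipe's code.

    import math

    import torch
    import torchaudio
    from torch.nn.utils.rnn import pad_sequence

    def batch_features(wav_paths, fbank, expected_sample_rate=16000):
        # Load mono 16 kHz wavs, compute per-utterance fbank features and pad
        # them into one (N, T, C) batch, mirroring the deleted pretrained.py.
        waves = []
        for path in wav_paths:
            wave, sample_rate = torchaudio.load(path)
            assert sample_rate == expected_sample_rate, (path, sample_rate)
            waves.append(wave[0])  # first channel only
        features = fbank(waves)  # per the deleted code, one (T_i, C) tensor per wav
        feature_lengths = torch.tensor([f.size(0) for f in features])
        # pad with log(1e-10) so padded frames look like near-silence in log-mel space
        features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
        return features, feature_lengths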
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --num-buckets 50 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --world-size 2 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --max-duration 200 \ - --num-buckets 50 - -""" - -# xxx -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import nvidia_smi -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=2048, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need " "to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 80000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, - reduction="none", -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - inf_flag = False - if not torch.all(is_finite): - inf_flag = True - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - # info["utterances"] = feature.size(0) - # # averaged input duration in frames over utterances - # info["utt_duration"] = feature_lens.sum().item() - # # averaged padding proportion over utterances - # info["utt_pad_proportion"] = ( - # ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - # ) - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info, inf_flag - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- if not inf_flag: - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - memory_debugging() - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - else: - logging.warning( - f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." - ) - continue - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def memory_debugging(): - # memory nvidia debugging - nvidia_smi.nvmlInit() - - deviceCount = nvidia_smi.nvmlDeviceGetCount() - for i in range(deviceCount): - handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i) - info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) - logging.info( - "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format( - i, - nvidia_smi.nvmlDeviceGetName(handle), - 100 * info.free / info.total, - info.total, - info.free, - info.used, - ) - ) - - nvidia_smi.nvmlShutdown() - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - MGB2 = MGB2AsrDataModule(args) - train_cuts = MGB2.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 30 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 0.5 <= c.duration <= 30.0 - - def remove_short_and_long_text(c: Cut): - # Keep only text with charachters between 20 and 450 - - return 20 <= len(c.supervisions[0].text) <= 450 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.filter(remove_short_and_long_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) - - valid_cuts = MGB2.dev_cuts() - valid_dl = MGB2.test_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - # clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - - -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/mgb2/ASR/shared b/egs/mgb2/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/mgb2/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/README.md b/egs/multi_ja_en/ASR/README.md deleted file mode 100644 index 09964a4ab..000000000 --- a/egs/multi_ja_en/ASR/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# Introduction - -A bilingual Japanese-English ASR model that utilizes ReazonSpeech, developed by the developers of ReazonSpeech. - -**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio. - - -# Included Training Sets - -1. LibriSpeech (English) -2. 
ReazonSpeech (Japanese) - -|Datset| Number of hours| URL| -|---|---:|---| -|**TOTAL**|35,960|---| -|LibriSpeech|960|https://www.openslr.org/12/| -|ReazonSpeech (all) |35,000|https://huggingface.co/datasets/reazon-research/reazonspeech| diff --git a/egs/multi_ja_en/ASR/RESULTS.md b/egs/multi_ja_en/ASR/RESULTS.md deleted file mode 100644 index 0f6996013..000000000 --- a/egs/multi_ja_en/ASR/RESULTS.md +++ /dev/null @@ -1,52 +0,0 @@ -## Results - -### Zipformer - -#### Non-streaming - -The training command is: - -```shell -./zipformer/train.py \ - --bilingual 1 \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 600 -``` - -The decoding command is: - -```shell -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search -``` - -To export the model with onnx: - -```shell -./zipformer/export-onnx.py --tokens data/lang_bbpe_2000/tokens.txt --use-averaged-model 0 --epoch 35 --avg 1 --exp-dir zipformer/exp --num-encoder-layers "2,2,3,4,3,2" --downsampling-factor "1,2,4,8,4,2" --feedforward-dim "512,768,1024,1536,1024,768" --num-heads "4,4,4,8,4,4" --encoder-dim "192,256,384,512,384,256" --query-head-dim 32 --value-head-dim 12 --pos-head-dim 4 --pos-dim 48 --encoder-unmasked-dim "192,192,256,256,256,192" --cnn-module-kernel "31,31,15,15,15,31" --decoder-dim 512 --joiner-dim 512 --causal False --chunk-size "16,32,64,-1" --left-context-frames "64,128,256,-1" --fp16 True -``` -Word Error Rates (WERs) listed below: - -| Datasets | ReazonSpeech | ReazonSpeech | LibriSpeech | LibriSpeech | -|----------------------|--------------|---------------|--------------------|-------------------| -| Zipformer WER (%) | dev | test | test-clean | test-other | -| greedy_search | 5.9 | 4.07 | 3.46 | 8.35 | -| modified_beam_search | 4.87 | 3.61 | 3.28 | 8.07 | - - -Character Error Rates (CERs) for Japanese listed below: -| Decoding Method | In-Distribution CER | JSUT | CommonVoice | TEDx | -| :------------------: | :-----------------: | :--: | :---------: | :---: | -| greedy search | 12.56 | 6.93 | 9.75 | 9.67 | -| modified beam search | 11.59 | 6.97 | 9.55 | 9.51 | - -Pre-trained model can be found here: https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/main - diff --git a/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py b/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py deleted file mode 100644 index af7841406..000000000 --- a/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
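# Example invocation (the manifest path below is hypothetical; point it at
# wherever the ReazonSpeech manifests were written during data preparation):
#
#   python3 ./local/compute_fbank_reazonspeech.py --manifest-dir data/manifests
#
# The directory must contain the
# reazonspeech_{recordings,supervisions}_{train,dev,test}.jsonl.gz manifests
# read below; the computed fbank features and the cut manifests are written
# back into the same directory.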
- - -import argparse -import logging -import os -from pathlib import Path -from typing import List, Tuple - -import torch - -# fmt: off -from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - RecordingSet, - SupervisionSet, -) - -# fmt: on - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -RNG_SEED = 42 -concat_params = {"gap": 1.0, "maxlen": 10.0} - - -def make_cutset_blueprints( - manifest_dir: Path, -) -> List[Tuple[str, CutSet]]: - cut_sets = [] - - # Create test dataset - logging.info("Creating test cuts.") - cut_sets.append( - ( - "test", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_test.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_test.jsonl.gz" - ), - ), - ) - ) - - # Create dev dataset - logging.info("Creating dev cuts.") - cut_sets.append( - ( - "dev", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_dev.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz" - ), - ), - ) - ) - - # Create train dataset - logging.info("Creating train cuts.") - cut_sets.append( - ( - "train", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_train.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_train.jsonl.gz" - ), - ), - ) - ) - return cut_sets - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("-m", "--manifest-dir", type=Path) - return parser.parse_args() - - -def main(): - args = get_args() - - extractor = Fbank(FbankConfig(num_mel_bins=80)) - num_jobs = min(16, os.cpu_count()) - - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - if (args.manifest_dir / ".reazonspeech-fbank.done").exists(): - logging.info( - "Previous fbank computed for ReazonSpeech found. " - f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank." - ) - return - else: - cut_sets = make_cutset_blueprints(args.manifest_dir) - for part, cut_set in cut_sets: - logging.info(f"Processing {part}") - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - num_jobs=num_jobs, - storage_path=(args.manifest_dir / f"feats_{part}").as_posix(), - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz") - - logging.info("All fbank computed for ReazonSpeech.") - (args.manifest_dir / ".reazonspeech-fbank.done").touch() - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/display_manifest_statistics.py b/egs/multi_ja_en/ASR/local/display_manifest_statistics.py deleted file mode 100644 index ace1dd73f..000000000 --- a/egs/multi_ja_en/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -from pathlib import Path - -from lhotse import CutSet, load_manifest - -ARGPARSE_DESCRIPTION = """ -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in -pruned_transducer_stateless5/train.py for usage. -""" - - -def get_parser(): - parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") - - return parser.parse_args() - - -def main(): - args = get_parser() - - for part in ["train", "dev"]: - path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz" - cuts: CutSet = load_manifest(path) - - print("\n---------------------------------\n") - print(path.name + ":") - cuts.describe() - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/prepare_char.py b/egs/multi_ja_en/ASR/local/prepare_char.py deleted file mode 120000 index 42743b544..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py deleted file mode 100755 index 27832ad1b..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script tokenizes the training transcript by CJK characters -# and saves the result to transcript_chars.txt, which is used -# to train the BPE model later. - -import argparse -import re -from pathlib import Path - -from tqdm.auto import tqdm - -from icefall.utils import tokenize_by_ja_char - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Output directory. - The generated transcript_chars.txt is saved to this directory. 
- """, - ) - - parser.add_argument( - "--text", - type=str, - help="Training transcript.", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - text = Path(args.text) - - assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" - - transcript_path = lang_dir / "transcript_chars.txt" - - with open(text, "r", encoding="utf-8") as fin: - with open(transcript_path, "w+", encoding="utf-8") as fout: - for line in tqdm(fin): - fout.write(tokenize_by_ja_char(line) + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/prepare_lang.py b/egs/multi_ja_en/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py deleted file mode 100755 index 6134710ad..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py +++ /dev/null @@ -1,268 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bbpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -import re -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.byte_utils import byte_encode -from icefall.utils import str2bool, tokenize_by_ja_char - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. 
- """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str], oov: str -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - oov: - The out of vocabulary word in lexicon. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - # Convert word to word piece IDs instead of word piece strings - # to avoid OOV tokens. - encode_words = [byte_encode(tokenize_by_ja_char(w)) for w in words] - words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int) - - # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--oov", - type=str, - default="", - help="The out of vocabulary word in lexicon.", - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. 
- """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bbpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", args.oov, "#0", "", ""] - - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_char.py b/egs/multi_ja_en/ASR/local/prepare_lang_char.py deleted file mode 100644 index 19c5f4a31..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_lang_char.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help=( - "Name of lang dir. 
" - "If not set, this will default to lang_char_{trans-mode}" - ), - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.basicConfig( - format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), - level=logging.INFO, - ) - - sysdef_string = set(["", "", "", " "]) - - token_set = set() - logging.info(f"Creating vocabulary from {args.train_cut}.") - train_cut: CutSet = CutSet.from_file(args.train_cut) - for cut in train_cut: - for sup in cut.supervisions: - token_set.update(sup.text) - - token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] - args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "tokens.txt").write_text( - "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) - ) - - (args.lang_dir / "lang_type").write_text("char") - logging.info("Done.") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/prepare_words.py b/egs/multi_ja_en/ASR/local/prepare_words.py deleted file mode 120000 index ef2b4eaf3..000000000 --- a/egs/multi_ja_en/ASR/local/prepare_words.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/text2segments.py b/egs/multi_ja_en/ASR/local/text2segments.py deleted file mode 100644 index e0f3a15c4..000000000 --- a/egs/multi_ja_en/ASR/local/text2segments.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# 2022 Xiaomi Corp. (authors: Weiji Zhuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input "text", which refers to the transcript file: - - text -and generates the output file with word segmentation implemented using MeCab: - - text_words_segmentation -""" - -import argparse -from multiprocessing import Pool - -import MeCab -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Japanese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--num-process", - "-n", - default=20, - type=int, - help="the number of processes", - ) - parser.add_argument( - "--input-file", - "-i", - default="data/lang_char/text", - type=str, - help="the input text file", - ) - parser.add_argument( - "--output-file", - "-o", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with word segmentation using MeCab", - ) - - return parser - - -def cut(lines): - if lines is not None: - mecab = MeCab.Tagger("-Owakati") # Use '-Owakati' option for word segmentation - segmented_line = mecab.parse(lines).strip() - return segmented_line.split() # Return as a list of words - else: - return None - - -def main(): - parser = get_parser() - args = parser.parse_args() - - num_process = args.num_process - input_file = args.input_file - output_file = args.output_file - - with open(input_file, "r", encoding="utf-8") as fr: - lines = fr.readlines() - - with Pool(processes=num_process) as p: - new_lines = list(tqdm(p.imap(cut, lines), total=len(lines))) - - with open(output_file, "w", encoding="utf-8") as fw: - for line in new_lines: - fw.write(" ".join(line) + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/text2token.py b/egs/multi_ja_en/ASR/local/text2token.py deleted file mode 100755 index ce64847c9..000000000 --- a/egs/multi_ja_en/ASR/local/text2token.py +++ /dev/null @@ -1,177 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import codecs -import re -import sys -from typing import List - -from romkan import to_roma # Replace with python-romkan v0.2.1 - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symbols, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "romaji"], - help="Transcript type. char/romaji", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "romaji", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the Japanese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "romaji". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "romaji": - text = [to_roma(c) for c in chars_list] - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "romaji": - a = [to_roma(c) for c in list(str(a))] - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = "".join(a_flat) - print(a_chars) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/train_bbpe_model.py b/egs/multi_ja_en/ASR/local/train_bbpe_model.py deleted file mode 100755 index d104f2717..000000000 --- a/egs/multi_ja_en/ASR/local/train_bbpe_model.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import re -import shutil -import tempfile -from pathlib import Path - -import sentencepiece as spm - -from icefall import byte_encode -from icefall.utils import tokenize_by_ja_char - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. 
- """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def _convert_to_bchar(in_path: str, out_path: str): - with open(out_path, "w") as f: - for line in open(in_path, "r").readlines(): - f.write(byte_encode(tokenize_by_ja_char(line)) + "\n") - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - model_file = Path(model_prefix + ".model") - if model_file.is_file(): - print(f"{model_file} exists - skipping") - return - - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - temp = tempfile.NamedTemporaryFile() - train_text = temp.name - - _convert_to_bchar(args.transcript, train_text) - - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - - shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py deleted file mode 100644 index be18e65c1..000000000 --- a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class ReazonSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). 
- It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/dev/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=False, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - - transforms = [] - input_transforms = [] - - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" - ) diff --git a/egs/multi_ja_en/ASR/local/utils/tokenizer.py b/egs/multi_ja_en/ASR/local/utils/tokenizer.py deleted file mode 100644 index ba71cff89..000000000 --- a/egs/multi_ja_en/ASR/local/utils/tokenizer.py +++ /dev/null @@ -1,252 +0,0 @@ -import argparse -from pathlib import Path -from typing import Callable, List, Union - -import sentencepiece as spm -from k2 import SymbolTable - - -class Tokenizer: - text2word: Callable[[str], List[str]] - - @staticmethod - def 
add_arguments(parser: argparse.ArgumentParser): - group = parser.add_argument_group(title="Lang related options") - group.add_argument("--lang", type=Path, help="Path to lang directory.") - - group.add_argument( - "--lang-type", - type=str, - default=None, - help=( - "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " - "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" - ), - ) - - @staticmethod - def Load(lang_dir: Path, lang_type="", oov=""): - - if not lang_type: - assert (lang_dir / "lang_type").exists(), "lang_type not specified." - lang_type = (lang_dir / "lang_type").read_text().strip() - - tokenizer = None - - if lang_type == "bpe": - assert ( - lang_dir / "bpe.model" - ).exists(), f"No BPE .model could be found in {lang_dir}." - tokenizer = spm.SentencePieceProcessor() - tokenizer.Load(str(lang_dir / "bpe.model")) - elif lang_type == "char": - tokenizer = CharTokenizer(lang_dir, oov=oov) - else: - raise NotImplementedError(f"{lang_type} not supported at the moment.") - - return tokenizer - - load = Load - - def PieceToId(self, piece: str) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - piece_to_id = PieceToId - - def IdToPiece(self, id: int) -> str: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - id_to_piece = IdToPiece - - def GetPieceSize(self) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - get_piece_size = GetPieceSize - - def __len__(self) -> int: - return self.get_piece_size() - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def EncodeAsIds(self, input: str) -> List[int]: - return self.EncodeAsIdsBatch([input])[0] - - def EncodeAsPieces(self, input: str) -> List[str]: - return self.EncodeAsPiecesBatch([input])[0] - - def Encode( - self, input: Union[str, List[str]], out_type=int - ) -> Union[List, List[List]]: - if not input: - return [] - - if isinstance(input, list): - if out_type is int: - return self.EncodeAsIdsBatch(input) - if out_type is str: - return self.EncodeAsPiecesBatch(input) - - if out_type is int: - return self.EncodeAsIds(input) - if out_type is str: - return self.EncodeAsPieces(input) - - encode = Encode - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodeIds(self, input: List[int]) -> str: - return self.DecodeIdsBatch([input])[0] - - def DecodePieces(self, input: List[str]) -> str: - return self.DecodePiecesBatch([input])[0] - - def Decode( - self, - input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], - ) -> Union[List[str], str]: - - if not input: - return "" - - if isinstance(input, int): - return self.id_to_piece(input) - elif isinstance(input, str): - raise TypeError( - "Unlike spm.SentencePieceProcessor, cannot decode from type str." 
- ) - - if isinstance(input[0], list): - if not input[0] or isinstance(input[0][0], int): - return self.DecodeIdsBatch(input) - - if isinstance(input[0][0], str): - return self.DecodePiecesBatch(input) - - if isinstance(input[0], int): - return self.DecodeIds(input) - if isinstance(input[0], str): - return self.DecodePieces(input) - - raise RuntimeError("Unknown input type") - - decode = Decode - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: - if isinstance(input, list): - return self.SplitBatch(input) - elif isinstance(input, str): - return self.SplitBatch([input])[0] - raise RuntimeError("Unknown input type") - - split = Split - - -class CharTokenizer(Tokenizer): - def __init__(self, lang_dir: Path, oov="<unk>", sep=""): - assert ( - lang_dir / "tokens.txt" - ).exists(), f"tokens.txt could not be found in {lang_dir}." - token_table = SymbolTable.from_file(lang_dir / "tokens.txt") - assert ( - "#0" not in token_table - ), "This tokenizer does not support disambig symbols." - self._id2sym = token_table._id2sym - self._sym2id = token_table._sym2id - self.oov = oov - self.oov_id = self._sym2id[oov] - self.sep = sep - if self.sep: - self.text2word = lambda x: x.split(self.sep) - else: - self.text2word = lambda x: list(x.replace(" ", "")) - - def piece_to_id(self, piece: str) -> int: - try: - return self._sym2id[piece] - except KeyError: - return self.oov_id - - def id_to_piece(self, id: int) -> str: - return self._id2sym[id] - - def get_piece_size(self) -> int: - return len(self._sym2id) - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - return [ - [i if i in self._sym2id else self.oov for i in self.text2word(text)] - for text in input - ] - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - return [self.sep.join(text) for text in input] - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - return [self.text2word(text) for text in input] - - -def test_CharTokenizer(): - test_single_string = "こんにちは" - test_multiple_string = [ - "今日はいい天気ですよね", - "諏訪湖は綺麗でしょう", - "这在词表外", - "分かち 書き に し た 文章 です", - "", - ] - test_empty_string = "" - sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>") - splitter = sp.split - print(sp.encode(test_single_string, out_type=str)) - print(sp.encode(test_single_string, out_type=int)) - print(sp.encode(test_multiple_string, out_type=str)) - print(sp.encode(test_multiple_string, out_type=int)) - print(sp.encode(test_empty_string, out_type=str)) - print(sp.encode(test_empty_string, out_type=int)) - print(sp.decode(sp.encode(test_single_string, out_type=str))) - print(sp.decode(sp.encode(test_single_string, out_type=int))) - print(sp.decode(sp.encode(test_multiple_string, out_type=str))) - print(sp.decode(sp.encode(test_multiple_string, out_type=int))) - print(sp.decode(sp.encode(test_empty_string, out_type=str))) - print(sp.decode(sp.encode(test_empty_string, out_type=int))) - print(splitter(test_single_string)) - print(splitter(test_multiple_string)) - print(splitter(test_empty_string)) - - -if __name__ == "__main__": - test_CharTokenizer()
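Note: the Tokenizer/CharTokenizer module deleted above mirrors the sentencepiece API (load/encode/decode/split). A minimal usage sketch of that interface, adapted from the deleted test_CharTokenizer (the data/lang_char path is illustrative):

from pathlib import Path

from tokenizer import Tokenizer  # the module deleted above

# Character-level tokenizer backed by lang_dir/tokens.txt; characters that are
# not in the token table fall back to the "<unk>" id.
sp = Tokenizer.load(Path("data/lang_char"), lang_type="char", oov="<unk>")
ids = sp.encode("こんにちは", out_type=int)    # list of token ids
chars = sp.encode("こんにちは", out_type=str)  # list of characters
text = sp.decode(ids)                          # joins the pieces back into a string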
diff --git a/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/validate_manifest.py b/egs/multi_ja_en/ASR/local/validate_manifest.py deleted file mode 100644 index 7f67c64b6..000000000 --- a/egs/multi_ja_en/ASR/local/validate_manifest.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - s = c.supervisions[0] - - # Removed because when the cuts were trimmed from supervisions, - # the start time of the supervision can be lesser than cut start time. 
- # https://github.com/lhotse-speech/lhotse/issues/813 - # if s.start < c.start: - # raise ValueError( - # f"{c.id}: Supervision start time {s.start} is less " - # f"than cut start time {c.start}" - # ) - - if s.end > c.end: - raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" - ) - - -def main(): - args = get_args() - - manifest = Path(args.manifest) - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest(manifest) - assert isinstance(cut_set, CutSet) - - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/multi_ja_en/ASR/prepare.sh b/egs/multi_ja_en/ASR/prepare.sh deleted file mode 100755 index 7a6a63418..000000000 --- a/egs/multi_ja_en/ASR/prepare.sh +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -vocab_sizes=( - 2000 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -log "Dataset: musan" -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Soft link fbank of musan" - mkdir -p data/fbank - if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" - exit 1 - fi -fi - -log "Dataset: LibriSpeech" -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 1: Soft link fbank of LibriSpeech" - mkdir -p data/fbank - if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . - cd ../.. - else - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 1 --stop-stage 1 and ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -log "Dataset: ReazonSpeech" -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 2: Soft link fbank of ReazonSpeech" - mkdir -p data/fbank - if [ -e ../../reazonspeech/ASR/data/manifests/.reazonspeech.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) . - cd .. - mkdir -p manifests - cd manifests - ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/feats_*) . - cd ../.. - else - log "Abort! 
Please run ../../reazonspeech/ASR/prepare.sh --stage 0 --stop-stage 2" - exit 1 - fi -fi - -# New Stage 3: Prepare char based lang for ReazonSpeech -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - lang_char_dir=data/lang_char - log "Stage 3: Prepare char based lang for ReazonSpeech" - mkdir -p $lang_char_dir - - # Prepare text - if [ ! -f $lang_char_dir/text ]; then - gunzip -c ../../reazonspeech/ASR/data/manifests/reazonspeech_supervisions_train.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - fi - - # jp word segmentation for text - if [ ! -f $lang_char_dir/text_words_segmentation ]; then - python3 ./local/text2segments.py \ - --input-file $lang_char_dir/text \ - --output-file $lang_char_dir/text_words_segmentation - fi - - cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt - - if [ ! -f $lang_char_dir/words.txt ]; then - python3 ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - python3 ./local/prepare_char.py --lang-dir data/lang_char - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare Byte BPE based lang" - mkdir -p data/fbank - if [ ! -d ../../reazonspeech/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then - log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi - - if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5" - exit 1 - fi - - cd data/ - # if [ ! -d ./lang_char ]; then - # ln -svf $(realpath ../../../reazonspeech/ASR/data/lang_char) . - # fi - if [ ! -d ./lang_bpe_500 ]; then - ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . - fi - cd ../ - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - mkdir -p $lang_dir - - cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ - > $lang_dir/text - - if [ ! -f $lang_dir/transcript_chars.txt ]; then - ./local/prepare_for_bpe_model.py \ - --lang-dir ./$lang_dir \ - --text $lang_dir/text - fi - - if [ ! -f $lang_dir/text_words_segmentation ]; then - python3 ./local/text2segments.py \ - --input-file ./data/lang_char/text \ - --output-file $lang_dir/text_words_segmentation - - cat ./data/lang_bpe_500/transcript_words.txt \ - >> $lang_dir/text_words_segmentation - fi - - cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt - - if [ ! -f $lang_dir/words.txt ]; then - python3 ./local/prepare_words.py \ - --input-file $lang_dir/words_no_ids.txt \ - --output-file $lang_dir/words.txt - fi - - if [ ! -f $lang_dir/bbpe.model ]; then - ./local/train_bbpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - - if [ ! 
-f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bbpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ln -svf $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/ - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bbpe.model - fi - done -fi - -log "prepare.sh: PREPARATION DONE" diff --git a/egs/multi_ja_en/ASR/shared b/egs/multi_ja_en/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/multi_ja_en/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py b/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a48591198..000000000 --- a/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/beam_search.py b/egs/multi_ja_en/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/multi_ja_en/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/ctc_decode.py b/egs/multi_ja_en/ASR/zipformer/ctc_decode.py deleted file mode 120000 index faa8bd562..000000000 --- a/egs/multi_ja_en/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py deleted file mode 100755 index 26ce3e018..000000000 --- a/egs/multi_ja_en/ASR/zipformer/decode.py +++ /dev/null @@ -1,792 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - -import argparse -import logging -import math -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params - -from icefall import byte_encode, smart_byte_decode -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - tokenize_by_ja_char, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_2000/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bbpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. 
For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(byte_encode(tokenize_by_ja_char(supervisions["text"]))), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, 
- ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [tokenize_by_ja_char(str(text)).split() for text in texts] - # print(texts) - # exit() - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) 
to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - data_module = ReazonSpeechAsrDataModule(args) - multi_dataset = MultiDataset(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/zipformer/decode_stream.py b/egs/multi_ja_en/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/multi_ja_en/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/decoder.py b/egs/multi_ja_en/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/multi_ja_en/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py deleted file mode 100755 index 072679cfc..000000000 --- a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py +++ /dev/null @@ -1,1261 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. 
" - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] 
= None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. 
- train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
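The inner update in `train_one_epoch` above is a standard mixed-precision step: forward under autocast, backward through the GradScaler, then step/update/zero_grad. A minimal sketch under the assumption that `compute_batch_loss` is any callable returning a scalar loss:

import torch

def amp_train_step(model, optimizer, scaler, compute_batch_loss, use_fp16: bool):
    # The forward pass runs in fp16 when enabled; step() unscales the
    # gradients before the optimizer update, and update() adjusts the
    # loss scale for the next iteration.
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = compute_batch_loss(model)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.detach()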
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
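The fp16 guard above doubles the grad scale when it has collapsed below 1.0 (or is still below 8.0 on every 400th batch) and aborts once it becomes vanishingly small. A simplified sketch of that heuristic; note that `scaler._scale` is a private PyTorch attribute which the recipe itself relies on:

def maybe_grow_grad_scale(scaler, batch_idx: int) -> float:
    cur_grad_scale = scaler._scale.item()
    # GradScaler.update() accepts an explicit new scale, so the scale can
    # be forced upward when it has shrunk too far.
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
    return cur_grad_scale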
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - logging.debug( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. 
" - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/zipformer/encoder_interface.py b/egs/multi_ja_en/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/multi_ja_en/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/export-onnx.py b/egs/multi_ja_en/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/multi_ja_en/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/export.py b/egs/multi_ja_en/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/multi_ja_en/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py deleted file mode 120000 index 5a015ee6c..000000000 --- a/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/joiner.py b/egs/multi_ja_en/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/multi_ja_en/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/model.py b/egs/multi_ja_en/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/multi_ja_en/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py deleted file mode 100644 index b0cdc1f6a..000000000 --- a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py +++ /dev/null @@ -1,143 +0,0 @@ -import argparse -import logging -from functools import lru_cache -from 
pathlib import Path -from typing import Dict - -from lhotse import CutSet, load_manifest_lazy - - -class MultiDataset: - def __init__(self, args: argparse.Namespace): - """ - Args: - manifest_dir: - It is expected to contain the following files: - - reazonspeech_cuts_train.jsonl.gz - - librispeech_cuts_train-clean-100.jsonl.gz - - librispeech_cuts_train-clean-360.jsonl.gz - - librispeech_cuts_train-other-500.jsonl.gz - """ - self.fbank_dir = Path(args.manifest_dir) - - def train_cuts(self) -> CutSet: - logging.info("About to get multidataset train cuts") - - logging.info("Loading Reazonspeech in lazy mode") - reazonspeech_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz" - ) - - logging.info("Loading LibriSpeech in lazy mode") - train_clean_100_cuts = self.train_clean_100_cuts() - train_clean_360_cuts = self.train_clean_360_cuts() - train_other_500_cuts = self.train_other_500_cuts() - - return CutSet.mux( - reazonspeech_cuts, - train_clean_100_cuts, - train_clean_360_cuts, - train_other_500_cuts, - weights=[ - len(reazonspeech_cuts), - len(train_clean_100_cuts), - len(train_clean_360_cuts), - len(train_other_500_cuts), - ], - ) - - def dev_cuts(self) -> CutSet: - logging.info("About to get multidataset dev cuts") - - logging.info("Loading Reazonspeech DEV set in lazy mode") - reazonspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" - ) - - logging.info("Loading LibriSpeech DEV set in lazy mode") - dev_clean_cuts = self.dev_clean_cuts() - dev_other_cuts = self.dev_other_cuts() - - return CutSet.mux( - reazonspeech_dev_cuts, - dev_clean_cuts, - dev_other_cuts, - weights=[ - len(reazonspeech_dev_cuts), - len(dev_clean_cuts), - len(dev_other_cuts), - ], - ) - - def test_cuts(self) -> Dict[str, CutSet]: - logging.info("About to get multidataset test cuts") - - logging.info("Loading Reazonspeech set in lazy mode") - reazonspeech_test_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz" - ) - reazonspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" - ) - - logging.info("Loading LibriSpeech set in lazy mode") - test_clean_cuts = self.test_clean_cuts() - test_other_cuts = self.test_other_cuts() - - test_cuts = { - "reazonspeech_test": reazonspeech_test_cuts, - "reazonspeech_dev": reazonspeech_dev_cuts, - "librispeech_test_clean": test_clean_cuts, - "librispeech_test_other": test_other_cuts, - } - - return test_cuts - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - 
logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" - ) diff --git a/egs/multi_ja_en/ASR/zipformer/my_profile.py b/egs/multi_ja_en/ASR/zipformer/my_profile.py deleted file mode 120000 index 3a90b2628..000000000 --- a/egs/multi_ja_en/ASR/zipformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/onnx_decode.py b/egs/multi_ja_en/ASR/zipformer/onnx_decode.py deleted file mode 120000 index 0573b88c5..000000000 --- a/egs/multi_ja_en/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/optim.py b/egs/multi_ja_en/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/multi_ja_en/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/pretrained.py b/egs/multi_ja_en/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/multi_ja_en/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/scaling.py b/egs/multi_ja_en/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/multi_ja_en/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/scaling_converter.py b/egs/multi_ja_en/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/multi_ja_en/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 935f86de1..000000000 --- a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,935 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
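`MultiDataset.train_cuts` in the deleted `multi_dataset.py` above mixes ReazonSpeech with the three LibriSpeech training subsets via `CutSet.mux`, weighting each source by its number of cuts so sampling roughly follows corpus size. A reduced sketch of that pattern, using the manifest names listed above (the weight computation mirrors the deleted code, including its use of `len()` on the lazily loaded manifests):

from pathlib import Path
from lhotse import CutSet, load_manifest_lazy

def mux_train_cuts(fbank_dir: Path) -> CutSet:
    # Larger corpora get proportionally larger sampling weights.
    sources = [
        load_manifest_lazy(fbank_dir / "reazonspeech_cuts_train.jsonl.gz"),
        load_manifest_lazy(fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"),
        load_manifest_lazy(fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"),
        load_manifest_lazy(fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"),
    ]
    return CutSet.mux(*sources, weights=[len(c) for c in sources])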
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - -Monolingual: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp-large \ - --lang data/lang_char \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 - -Bilingual: -./zipformer/streaming_decode.py \ - --bilingual 1 \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp-large \ - --lang data/lang_char \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - -""" - -import argparse -import logging -import math -import os -import pdb -import subprocess as sp -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import ReazonSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bilingual", - type=str2bool, - default=False, - help="Whether the model is bilingual or not. 1 = bilingual.", - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
- "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). 
- state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
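`stack_states` and `unstack_states` above are inverses built on `torch.cat` and `Tensor.chunk` along the batch dimension. A toy round-trip illustrating the idea on a single cached tensor (shapes chosen only for the example):

import torch

# Per-utterance cached tensors of shape (left_context_len, 1, dim).
per_utt = [torch.randn(4, 1, 8) for _ in range(3)]

# "stack": concatenate along the batch dimension (dim=1 for this tensor).
batched = torch.cat(per_utt, dim=1)          # (4, 3, 8)

# "unstack": split back into one tensor per utterance.
recovered = batched.chunk(chunks=3, dim=1)   # 3 tensors of shape (4, 1, 8)

assert all(torch.equal(a, b) for a, b in zip(per_utt, recovered))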
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=model.device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
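The `processed_mask` built in `streaming_forward` above marks which cached left-context slots are still unused (fewer frames processed so far than the left-context length), flipped so the unused slots sit at the left edge. A small sketch of just that mask, with a concrete example:

import torch

def left_context_padding_mask(
    processed_lens: torch.Tensor, left_context_len: int
) -> torch.Tensor:
    # True marks a left-context position that has not been filled yet.
    idx = torch.arange(left_context_len, device=processed_lens.device)
    idx = idx.expand(processed_lens.size(0), left_context_len)
    return (processed_lens.unsqueeze(1) <= idx).flip(1)

# processed_lens = [0, 3] with left_context_len = 4 gives
# [[True, True, True, True],
#  [True, False, False, False]]
print(left_context_padding_mask(torch.tensor([0, 3]), 4))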
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=model.device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - # finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
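The tail-padding rule in `decode_one_chunk` above guarantees that each chunk still yields a full chunk of encoder frames: the embedding front-end consumes 7 frames and subsamples by 2, and its ConvNeXt block needs 3 extra subsampled frames of right context. A sketch of that padding step, with `LOG_EPS` as defined earlier in this file:

import math
import torch

LOG_EPS = math.log(1e-10)

def pad_chunk_features(
    features: torch.Tensor, feature_lens: torch.Tensor, chunk_size: int
):
    # features: (N, T, C).  We need chunk_size * 2 + 7 + 2 * 3 input
    # frames so the subsampled output still covers a whole chunk.
    tail_length = chunk_size * 2 + 7 + 2 * 3
    if features.size(1) < tail_length:
        pad = tail_length - features.size(1)
        feature_lens = feature_lens + pad
        features = torch.nn.functional.pad(
            features, (0, 0, 0, pad), mode="constant", value=LOG_EPS
        )
    return features, feature_lens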
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - - if not finished_streams: - print("No finished streams, breaking the loop") - break - - for i in sorted(finished_streams, reverse=True): - try: - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - except IndexError as e: - print(f"IndexError: {e}") - print(f"decode_streams length: {len(decode_streams)}") - print(f"finished_streams: {finished_streams}") - print(f"i: {i}") - continue - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - torch.cuda.synchronize() - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
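Each cut above is converted to 80-dim fbank features on the decoding device before being attached to its `DecodeStream`. A reduced sketch of that feature extraction, mirroring the kaldifeat options set in `decode_dataset` (the helper name is illustrative):

import numpy as np
import torch
from kaldifeat import Fbank, FbankOptions

def compute_streaming_fbank(audio: np.ndarray, device: torch.device) -> torch.Tensor:
    # audio: (1, num_samples), normalised roughly to [-1, 1].
    assert audio.ndim == 2 and audio.shape[0] == 1
    opts = FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    fbank = Fbank(opts)
    samples = torch.from_numpy(audio).squeeze(0).to(device)
    return fbank(samples)  # (num_frames, 80)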
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - if not params.bilingual: - sp = Tokenizer.load(params.lang, params.lang_type) - else: - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - 
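When `--use-averaged-model` is off and `--avg > 1`, the code above simply averages the state dicts of the last `avg` epoch checkpoints ending at `--epoch`. A sketch of the filename selection (the guard against non-positive epochs is written slightly differently from the deleted code):

from typing import List

def epoch_checkpoints_to_average(exp_dir: str, epoch: int, avg: int) -> List[str]:
    # e.g. epoch=28, avg=15 selects epoch-14.pt ... epoch-28.pt
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]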
model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - - if params.bilingual: - multi_dataset = MultiDataset(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = multi_dataset.test_cuts() - test_sets = test_sets_cuts.keys() - test_cuts = [test_sets_cuts[k] for k in test_sets] - - valid_cuts = reazonspeech_corpus.valid_cuts() - test_cuts = reazonspeech_corpus.test_cuts() - - test_sets = ["valid", "test"] - test_cuts = [valid_cuts, test_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - logging.info(f"Decoding {test_set}") - if params.bilingual: - test_cut = test_cut.filter(remove_short_utt) - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/zipformer/subsampling.py b/egs/multi_ja_en/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/multi_ja_en/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/test_scaling.py b/egs/multi_ja_en/ASR/zipformer/test_scaling.py deleted file mode 120000 index 715798436..000000000 --- a/egs/multi_ja_en/ASR/zipformer/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/test_subsampling.py b/egs/multi_ja_en/ASR/zipformer/test_subsampling.py deleted file mode 120000 index bf0ee3d11..000000000 --- a/egs/multi_ja_en/ASR/zipformer/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/tokenizer.py b/egs/multi_ja_en/ASR/zipformer/tokenizer.py deleted file mode 120000 index 958c99e85..000000000 --- a/egs/multi_ja_en/ASR/zipformer/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/train.py b/egs/multi_ja_en/ASR/zipformer/train.py deleted file mode 100755 index bfb037f50..000000000 --- a/egs/multi_ja_en/ASR/zipformer/train.py +++ /dev/null @@ -1,1462 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
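Both the training-time filter `remove_short_and_long_utt` earlier in this patch and `remove_short_utt` above rely on the same estimate of the encoder's output length after the Conv2d subsampling front-end. A sketch of that estimate:

def num_frames_after_subsampling(num_frames: int) -> int:
    # The encoder_embed front-end reduces T input frames to roughly T/4:
    # first (T - 7) // 2 + 1 frames, then a further factor-2 subsampling.
    return ((num_frames - 7) // 2 + 1) // 2

# Decoding keeps a cut only if it still has at least one output frame;
# pruned RNN-T training additionally requires T >= number of tokens.
assert num_frames_after_subsampling(100) == 23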
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --bilingual 1 \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 600 - -# For streaming model training: -./zipformer/train.py \ - --bilingual 1 \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 600 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - -import argparse -import copy -import logging -import os -import re -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from multi_dataset import MultiDataset -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import byte_encode, diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, - tokenize_by_ja_char, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bilingual", - type=str2bool, - default=False, - help="Whether the model is bilingual or not. 1 = bilingual.", - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - # changed - not used in monolingual streaming - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_2000/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.015, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. 
- - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - 
vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
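-      scheduler:
-        The learning rate scheduler used in the training.
-      rank:
-        The rank of the node in DDP training; only rank 0 saves checkpoints.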
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - sentencepiece_processor: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - if not params.bilingual: - y = tokenizer.encode(texts, out_type=int) - else: - y = sentencepiece_processor.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
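-    # Since the losses above use reduction=sum, `loss` is a sum over all
-    # utterances in the batch; the callers later divide by info["frames"]
-    # to report a per-frame loss.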
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - sentencepiece_processor: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - tokenizer: Tokenizer, - sentencepiece_processor: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
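-      tokenizer:
-        The monolingual tokenizer, used to encode texts when --bilingual is False.
-      sentencepiece_processor:
-        The BPE SentencePieceProcessor, used to encode texts when --bilingual is True.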
- """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch( - batch, - params=params, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - ) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
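-            # Doubling the scale below counteracts the automatic halving that the
-            # grad scaler applies whenever a batch produces inf/nan gradients, so
-            # the scale can recover instead of staying stuck at a small value.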
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
-      args:
-        The return value of get_parser().parse_args()
-    """
-    params = get_params()
-    params.update(vars(args))
-
-    fix_random_seed(params.seed)
-    if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
-
-    setup_logger(f"{params.exp_dir}/log/log-train")
-    logging.info("Training started")
-
-    if args.tensorboard and rank == 0:
-        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
-    else:
-        tb_writer = None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
-
-    # Use lang_dir for further operations
-    # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-
-    # sentencepiece_processor = spm.SentencePieceProcessor()
-    # sentencepiece_processor.load(params.bpe_model)
-    tokenizer = None
-    sentencepiece_processor = None
-
-    # <blk> is defined in local/prepare_lang_char.py
-
-    if not params.bilingual:
-        tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        params.blank_id = tokenizer.piece_to_id("<blk>")
-        params.vocab_size = tokenizer.get_piece_size()
-    else:
-        sentencepiece_processor = spm.SentencePieceProcessor()
-        sentencepiece_processor.load(params.bpe_model)
-        params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-        params.vocab_size = sentencepiece_processor.get_piece_size()
-
-    if not params.use_transducer:
-        params.ctc_loss_scale = 1.0
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_model(params)
-
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
-
-    assert params.save_every_n >= params.average_period
-    model_avg: Optional[nn.Module] = None
-    if rank == 0:
-        # model_avg is only used with rank 0
-        model_avg = copy.deepcopy(model).to(torch.float64)
-
-    assert params.start_epoch > 0, params.start_epoch
-    checkpoints = load_checkpoint_if_available(
-        params=params, model=model, model_avg=model_avg
-    )
-
-    model.to(device)
-    if world_size > 1:
-        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
-
-    optimizer = ScaledAdam(
-        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
-        lr=params.base_lr,  # should have no effect
-        clipping_scale=2.0,
-    )
-
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
-
-    if checkpoints and "optimizer" in checkpoints:
-        logging.info("Loading optimizer state dict")
-        optimizer.load_state_dict(checkpoints["optimizer"])
-
-    if (
-        checkpoints
-        and "scheduler" in checkpoints
-        and checkpoints["scheduler"] is not None
-    ):
-        logging.info("Loading scheduler state dict")
-        scheduler.load_state_dict(checkpoints["scheduler"])
-
-    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            512
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
-
-    if params.inf_check:
-        register_inf_check_hooks(model)
-
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    if params.bilingual:
-        multi_dataset = MultiDataset(args)
-        train_cuts = multi_dataset.train_cuts()
-    else:
-        train_cuts = reazonspeech_corpus.train_cuts()
-
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        # if c.duration < 1.0 or c.duration > 30.0:
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-        #     )
-        #     return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_samples - 7) // 2 + 1) // 2
-        if not params.bilingual:
-            tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
-        else:
-            tokens = sentencepiece_processor.encode(
-                c.supervisions[0].text, out_type=str
-            )
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_samples}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
-        return True
-
-    def tokenize_and_encode_text(c: Cut):
-        # Text normalize for each sample
-        text = c.supervisions[0].text
-        text = byte_encode(tokenize_by_ja_char(text))
-        c.supervisions[0].text = text
-        return c
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    if params.bilingual:
-        train_cuts = train_cuts.map(tokenize_and_encode_text)
-
-    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
-        # We only load the sampler's state dict when it loads a checkpoint
-        # saved in the middle of an epoch
-        sampler_state_dict = checkpoints["sampler"]
-    else:
-        sampler_state_dict = None
-
-    train_dl = reazonspeech_corpus.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
-
-    if params.bilingual:
-        valid_cuts = multi_dataset.dev_cuts()
-    else:
-        valid_cuts = reazonspeech_corpus.valid_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
-
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            tokenizer=tokenizer,
-            sentencepiece_processor=sentencepiece_processor,
-            params=params,
-        )
-
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
-    if checkpoints and "grad_scaler" in checkpoints:
-        logging.info("Loading grad scaler state dict")
-        scaler.load_state_dict(checkpoints["grad_scaler"])
-
-    for epoch in range(params.start_epoch, params.num_epochs + 1):
-        scheduler.step_epoch(epoch - 1)
-        fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
-
-        if tb_writer is not None:
-            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
-
-        params.cur_epoch = epoch
-
-        train_one_epoch(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            tokenizer=tokenizer,
-            sentencepiece_processor=sentencepiece_processor,
-            train_dl=train_dl,
-            valid_dl=valid_dl,
-            scaler=scaler,
-            tb_writer=tb_writer,
-            world_size=world_size,
-            rank=rank,
-        )
-
-        if params.print_diagnostics:
-            diagnostic.print_diagnostics()
-            break
-
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
-
-    logging.info("Done!")
-
-    if world_size > 1:
-        torch.distributed.barrier()
-        cleanup_dist()
-
-
-def display_and_save_batch(
batch: dict, - params: AttributeDict, - tokenizer: Tokenizer, - sentencepiece_processor: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - tokenizer: - The BPE Tokenizer model. - sentencepiece_processor: - The BPE SentencePieceProcessor model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - if params.bilingual: - y = sentencepiece_processor.encode(supervisions["text"], out_type=int) - else: - y = tokenizer.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - tokenizer: Tokenizer, - sentencepiece_processor: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch( - batch, - params=params, - tokenizer=tokenizer, - sentencepiece_processor=sentencepiece_processor, - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/multi_ja_en/ASR/zipformer/zipformer.py b/egs/multi_ja_en/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/multi_ja_en/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/README.md b/egs/multi_zh-hans/ASR/README.md deleted file mode 100644 index 1e60c733c..000000000 --- a/egs/multi_zh-hans/ASR/README.md +++ /dev/null @@ -1,39 +0,0 @@ - -# Introduction - -This recipe includes scripts for training Zipformer model using multiple Chinese datasets. - -# Included Training Sets -1. THCHS-30 -2. 
AiShell-{1,2,4} -3. ST-CMDS -4. Primewords -5. MagicData -6. Aidatatang_200zh -7. AliMeeting -8. WeNetSpeech -9. KeSpeech-ASR - -|Datset| Number of hours| URL| -|---|---:|---| -|**TOTAL**|14,106|---| -|THCHS-30|35|https://www.openslr.org/18/| -|AiShell-1|170|https://www.openslr.org/33/| -|AiShell-2|1,000|http://www.aishelltech.com/aishell_2| -|AiShell-4|120|https://www.openslr.org/111/| -|ST-CMDS|110|https://www.openslr.org/38/| -|Primewords|99|https://www.openslr.org/47/| -|aidatatang_200zh|200|https://www.openslr.org/62/| -|MagicData|755|https://www.openslr.org/68/| -|AliMeeting|100|https://openslr.org/119/| -|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech| -|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech| - - -# Included Test Sets -1. Aishell-{1,2,4} -2. Aidatatang_200zh -3. AliMeeting -4. MagicData -5. KeSpeech-ASR -6. WeNetSpeech diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md deleted file mode 100644 index 622218d02..000000000 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ /dev/null @@ -1,233 +0,0 @@ -## Results - -### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2 -#### Whisper -[./whisper](./whisper) - -Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. - -|Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | -|-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| -| | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | -|whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | -|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 | - -Command for training is: -```bash -pip install -r whisper/requirements.txt - -# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54 - -torchrun --nproc-per-node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json -``` - -Command for decoding using fine-tuned models: -```bash -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper -ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 -``` - -Fine-tuned models, training logs, decoding logs, tensorboard and decoding results -are available at - - -### Multi Chinese datasets char-based training results (streaming) on zipformer-xl model - -#### Streaming (with CTC head) - -The training command for extra-large model (num of params : ~700M): - -Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features. 
-
-```
-./zipformer/train.py \
-  --world-size 8 \
-  --num-epochs 20 \
-  --use-fp16 1 \
-  --max-duration 1200 \
-  --num-workers 8 \
-  --use-ctc 1 \
-  --exp-dir zipformer/exp-xl \
-  --causal 1 \
-  --num-encoder-layers 2,3,5,6,5,3 \
-  --feedforward-dim 1536,2048,3072,4096,3072,1536 \
-  --encoder-dim 512,768,1024,1536,1024,512 \
-  --encoder-unmasked-dim 192,192,256,320,256,192 \
-  --decoder-dim 768 --joiner-dim 768 \
-  --value-head-dim 18 \
-  --query-head-dim 48 \
-  --num-heads 4,4,4,8,4,4
-
-```
-
-The decoding command for transducer greedy search:
-
-```
-./zipformer/decode.py \
-  --epoch 999 \
-  --avg 1 \
-  --causal 1 \
-  --use-averaged-model False \
-  --chunk-size -1 \
-  --left-context-frames -1 \
-  --use-ctc 1 \
-  --exp-dir zipformer/exp-xl \
-  --max-duration 1200 \
-  --num-encoder-layers 2,3,5,6,5,3 \
-  --feedforward-dim 1536,2048,3072,4096,3072,1536 \
-  --encoder-dim 512,768,1024,1536,1024,512 \
-  --encoder-unmasked-dim 192,192,256,320,256,192 \
-  --decoder-dim 768 --joiner-dim 768 \
-  --value-head-dim 18 \
-  --query-head-dim 48 \
-  --num-heads 4,4,4,8,4,4
-```
-
-Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using a BPE model (# tokens is 2000, byte fallback enabled).
-
-| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
-| Transducer Greedy Offline | 21.67 | 23.43 | 1.22 | 1.31 | 3.17 | 3.27 | 14.64 | 2.42 | 1.99 | 5.00 | 2.29 | 5.98 | 5.15 | 5.85 | 6.89 |
-
-Pre-trained model can be found here: https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-xl
-### Multi Chinese datasets char-based training results (streaming) on zipformer large model
-
-#### Streaming (with CTC head)
-
-The training command for large model (num of params: ~160M):
-
-Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.
-
-```
-./zipformer/train.py \
-  --world-size 8 \
-  --num-epochs 20 \
-  --use-fp16 1 \
-  --max-duration 1200 \
-  --num-workers 8 \
-  --use-ctc 1 \
-  --exp-dir zipformer/exp-large \
-  --causal 1 \
-  --num-encoder-layers 2,2,4,5,4,2 \
-  --feedforward-dim 768,1024,1536,2048,1536,768 \
-  --encoder-dim 256,384,512,768,512,256 \
-  --blank-penalty 0.7 \
-  --encoder-unmasked-dim 192,192,256,320,256,192
-
-```
-
-The decoding command for transducer greedy search:
-
-```
-./zipformer/decode.py \
-  --epoch 999 \
-  --avg 1 \
-  --causal 1 \
-  --use-averaged-model False \
-  --chunk-size -1 \
-  --left-context-frames -1 \
-  --use-ctc 1 \
-  --exp-dir zipformer/exp-large \
-  --max-duration 1200 \
-  --num-encoder-layers 2,2,4,5,4,2 \
-  --feedforward-dim 768,1024,1536,2048,1536,768 \
-  --encoder-dim 256,384,512,768,512,256 \
-  --encoder-unmasked-dim 192,192,256,320,256,192
-```
-
-Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using a BPE model (# tokens is 2000, byte fallback enabled).
- -| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | -|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Greedy Streaming | 26.50 | 28.10| 1.71 | 1.97| 3.89| 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 |9.51 | 6.11 | 8.13 | 10.62 | -| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 | -| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 | -| Transducer Greedy Streaming | 26.83|28.74 | 1.75 | 1.91 | 3.84 | 4.12 | 17.83 | 3.23 | 2.71 | 7.31 | 3.16 | 8.69 | 5.71 | 7.91 | 8.54 | - -Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large - -### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model - -This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. - -#### Non-streaming (with CTC head) - -Best results (num of params : ~69M): - -The training command: - -``` -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 20 \ - --use-fp16 1 \ - --max-duration 600 \ - --num-workers 8 \ - --use-ctc 1 -``` - -The decoding command: - -``` -./zipformer/decode.py \ - --epoch 20 \ - --avg 1 \ - --use-ctc 1 -``` - -Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled). 
- -| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | -|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | -| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | - -Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ - -#### Non-streaming - -Best results (num of params : ~69M): - -The training command: - -``` -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 20 \ - --use-fp16 1 \ - --max-duration 600 \ - --num-workers 8 -``` - -The decoding command: - -``` -./zipformer/decode.py \ - --epoch 20 \ - --avg 1 -``` - -Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled). - -| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | -|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | - - -Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ diff --git a/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py b/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py deleted file mode 100755 index d078e5b98..000000000 --- a/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 - -""" -This script takes `bpe.model` as input and generates a file `tokens.txt` -from it. 
- -Usage: -./bpe_model_to_tokens.py /path/to/input/bpe.model > tokens.txt -""" -import argparse - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "bpe_model", - type=str, - help="Path to the input bpe.model", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - for i in range(sp.vocab_size()): - print(sp.id_to_piece(i), i) - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/compile_lg.py b/egs/multi_zh-hans/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/multi_zh-hans/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py deleted file mode 100755 index 2bbe28560..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# Copyright 2023 Xiaomi Corp. (Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--speed-perturb", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", - ) - - return parser - - -def compute_fbank_kespeech_dev_test(args): - in_out_dir = Path("data/fbank/kespeech") - # number of workers in dataloader - num_workers = 42 - - # number of seconds in a batch - batch_duration = 600 - - subsets = ( - "dev_phase1", - "dev_phase2", - "test", - ) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - if args.whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) - ) - else: - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - - logging.info(f"device: {device}") - - for partition in subsets: - cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - if args.speed_perturb: - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{in_out_dir}/feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - compute_fbank_kespeech_dev_test(args) - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py deleted file mode 100755 index fe7f87337..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ /dev/null @@ -1,214 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# Copyright 2023 Xiaomi Corp. (Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from datetime import datetime -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, - set_audio_duration_mismatch_tolerance, - set_caching_enabled, -) - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. 
-# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--training-subset", - type=str, - default="train_phase1", - choices=["train_phase1", "train_phase2"], - help="The training subset for computing fbank feature.", - ) - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--num-splits", - type=int, - required=True, - help="The number of splits of the given subset", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="Process pieces starting from this number (inclusive).", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop processing pieces until this number (exclusive).", - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - - parser.add_argument( - "--speed-perturb", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - - return parser - - -def compute_fbank_kespeech_splits(args): - subset = args.training_subset - subset = str(subset) - num_splits = args.num_splits - output_dir = f"data/fbank/kespeech/{subset}_split_{num_splits}" - output_dir = Path(output_dir) - assert output_dir.exists(), f"{output_dir} does not exist!" 
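-    # The raw split manifests processed below are expected to be named
-    # kespeech-asr_cuts_{subset}_raw.{idx}.jsonl.gz, with idx zero-padded
-    # to as many digits as num_splits has.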
- - num_digits = len(str(num_splits)) - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - if args.whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance - set_caching_enabled(False) - for i in range(start, stop): - idx = f"{i}".zfill(num_digits) - logging.info(f"Processing {i+1}/{num_splits}") - - cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = output_dir / f"kespeech-asr_cuts_{subset}_raw.{idx}.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - if args.speed_perturb: - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/feats_{subset}_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - now = datetime.now() - date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - - log_filename = "log-compute_fbank_kespeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - log_filename = f"{log_filename}-{date_time}" - - logging.basicConfig( - filename=log_filename, - format=formatter, - level=logging.INFO, - filemode="w", - ) - - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter(formatter)) - logging.getLogger("").addHandler(console) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - compute_fbank_kespeech_splits(args) - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py deleted file mode 100755 index 192bffa9f..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the MagicData dataset. -It looks for manifests in the directory data/manifests/magicdata. - -The generated fbank features are saved in data/fbank. 
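-
-Usage example (the flags correspond to the options defined in get_args() below;
-the script is assumed to be invoked from the recipe's ASR directory):
-
-    ./local/compute_fbank_magicdata.py --num-mel-bins 80 --speed-perturb True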
-""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - return parser - - -def compute_fbank_magicdata( - num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/magicdata") - output_dir = Path("data/fbank") - num_jobs = min(8, os.cpu_count()) - - dataset_parts = ("train", "test", "dev") - prefix = "magicdata" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if args.whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and speed_perturb: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--speed-perturb", - type=bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. 
Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_magicdata( - num_mel_bins=args.num_mel_bins, - speed_perturb=args.speed_perturb, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py deleted file mode 100755 index 019b10d24..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the Primewords dataset. -It looks for manifests in the directory data/manifests/primewords. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_primewords( - num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/primewords") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ("train",) - prefix = "primewords" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and speed_perturb: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--speed-perturb", - type=bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_primewords( - num_mel_bins=args.num_mel_bins, - speed_perturb=args.speed_perturb, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py deleted file mode 100755 index f29ae5a46..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the ST-CMDS dataset. -It looks for manifests in the directory data/manifests/stcmds. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_stcmds( - num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/stcmds") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ("train",) - prefix = "stcmds" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition and speed_perturb: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--speed-perturb", - type=bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_stcmds( - num_mel_bins=args.num_mel_bins, - speed_perturb=args.speed_perturb, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py deleted file mode 100755 index 4ad78e0ba..000000000 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the THCHS-30 dataset. -It looks for manifests in the directory data/manifests/thchs30. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_thchs30( - num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False -): - src_dir = Path("data/manifests/thchs30") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train", - "dev", - "test", - ) - prefix = "thchs_30" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) - if speed_perturb - else cut_set - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--speed-perturb", - type=bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. 
Default: False.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_thchs30( - num_mel_bins=args.num_mel_bins, - speed_perturb=args.speed_perturb, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/multi_zh-hans/ASR/local/prepare_char.py b/egs/multi_zh-hans/ASR/local/prepare_char.py deleted file mode 120000 index be7da61af..000000000 --- a/egs/multi_zh-hans/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py deleted file mode 100755 index 020800c15..000000000 --- a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script tokenizes the training transcript by CJK characters -# and saves the result to transcript_chars.txt, which is used -# to train the BPE model later. - -import argparse -from pathlib import Path - -from tqdm.auto import tqdm - -from icefall.utils import tokenize_by_CJK_char - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Output directory. - The generated transcript_chars.txt is saved to this directory. - """, - ) - - parser.add_argument( - "--text", - type=str, - help="WenetSpeech training transcript.", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - text = Path(args.text) - - assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
- - transcript_path = lang_dir / "transcript_chars.txt" - - with open(text, "r", encoding="utf-8") as fin: - with open(transcript_path, "w+", encoding="utf-8") as fout: - for line in fin: - fout.write(tokenize_by_CJK_char(line) + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/prepare_lang.py b/egs/multi_zh-hans/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/multi_zh-hans/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py b/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py b/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py deleted file mode 100755 index 20274263f..000000000 --- a/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# Copyright 2023 Xiaomi Corp. (Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import re -from pathlib import Path - -from lhotse import CutSet, SupervisionSegment -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall import setup_logger - -# Similar text filtering and normalization procedure as in: -# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh - - -def normalize_text( - utt: str, - punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), - whitespace_pattern=re.compile(r"\s\s+"), -) -> str: - return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) - - -def has_no_oov( - sup: SupervisionSegment, - oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER|SPOKEN_NOISE)>"), -) -> bool: - return oov_pattern.search(sup.text) is None - - -def preprocess_kespeech(speed_perturb: bool = False): - src_dir = Path("data/manifests/kespeech") - output_dir = Path("data/fbank/kespeech") - output_dir.mkdir(exist_ok=True) - - # Note: By default, we preprocess all sub-parts. - # You can delete those that you don't need. 
- # For instance, if you don't want to use the test subpart, just remove - # the line below containing "test" - dataset_parts = ( - "dev_phase1", - "dev_phase2", - "test", - "train_phase1", - "train_phase2", - ) - - logging.info("Loading manifest (may take 10 minutes)") - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix="jsonl.gz", - prefix="kespeech-asr", - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - logging_threshold = 50 - logging_count = 0 - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - # Note this step makes the recipe different than LibriSpeech: - # We must filter out some utterances and remove punctuation - # to be consistent with Kaldi. - logging.info("Filtering OOV utterances from supervisions") - m["supervisions"] = m["supervisions"].filter(has_no_oov) - logging.info(f"Normalizing text in {partition}") - for sup in m["supervisions"]: - orig_text = sup.text - sup.text = normalize_text(sup.text) - if logging_count < logging_threshold and len(orig_text) != len(sup.text): - logging_count += 1 - logging.info( - f"\nOriginal text vs normalized text:\n{orig_text}\n{sup.text}" - ) - - # Create long-recording cut manifests. - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - # Run data augmentation that needs to be done in the - # time domain. - if partition not in [ - "dev_phase1", - "dev_phase2", - "test", - ]: - if speed_perturb: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--speed-perturb", - type=bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", - ) - return parser.parse_args() - - -def main(): - setup_logger(log_filename="./log-preprocess-kespeech") - - args = get_args() - preprocess_kespeech(speed_perturb=args.speed_perturb) - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/text2token.py b/egs/multi_zh-hans/ASR/local/text2token.py deleted file mode 120000 index ce5cfd537..000000000 --- a/egs/multi_zh-hans/ASR/local/text2token.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/train_bpe_model.py b/egs/multi_zh-hans/ASR/local/train_bpe_model.py deleted file mode 100755 index 976ea0ba8..000000000 --- a/egs/multi_zh-hans/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - parser.add_argument( - "--byte-fallback", - type=bool, - default=True, - help="Enable byte fallback for BPE model.", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 0.98 - input_sentence_size = 100000000 - - user_defined_symbols = ["<blk>", "<sos/eos>"] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - byte_fallback=args.byte_fallback, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh deleted file mode 100755 index 3d2a9471c..000000000 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ /dev/null @@ -1,493 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 -num_splits=100 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -vocab_sizes=( - 2000 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it.
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -log "Dataset: musan" -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Soft link fbank of musan" - mkdir -p data/fbank - if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" - exit 1 - fi -fi - -log "Dataset: THCHS-30" -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare THCHS-30" - if [ ! -d $dl_dir/thchs30 ]; then - log "Downloading THCHS-30" - lhotse download thchs-30 $dl_dir/thchs30 - fi - - if [ ! -f data/manifests/.thchs30.done ]; then - mkdir -p data/manifests - lhotse prepare thchs-30 $dl_dir/thchs30 data/manifests/thchs30 - touch data/manifests/.thchs30.done - fi - - if [ ! -f data/fbank/.thchs30.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_thchs30.py --speed-perturb true - touch data/fbank/.thchs30.done - fi -fi - -log "Dataset: AISHELL-1" -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare AISHELL-1" - if [ -e ../../aishell/ASR/data/fbank/.aishell.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_train) . - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_dev) . - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_test) . - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_train.jsonl.gz) . - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_dev.jsonl.gz) . - ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_test.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../aishell/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -log "Dataset: AISHELL-2" -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare AISHELL-2" - if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_train) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_dev) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_test) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_train.jsonl.gz) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -log "Dataset: AISHELL-4" -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare AISHELL-4" - if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_L) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_M) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train_S) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) . 
- ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -log "Dataset: ST-CMDS" -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare ST-CMDS" - if [ ! -f $dl_dir/stcmds/ST-CMDS-20170001_1-OS.tar.gz ]; then - log "Downloading ST-CMDS" - lhotse download stcmds $dl_dir/stcmds - fi - - if [ ! -f data/manifests/.stcmds.done ]; then - mkdir -p data/manifests - lhotse prepare stcmds $dl_dir/stcmds data/manifests/stcmds - touch data/manifests/.stcmds.done - fi - - if [ ! -f data/fbank/.stcmds.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_stcmds.py --speed-perturb true - touch data/fbank/.stcmds.done - fi -fi - - -log "Dataset: Primewords" -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare Primewords" - if [ ! -f $dl_dir/primewords/primewords_md_2018_set1.tar.gz ]; then - log "Downloading Primewords" - lhotse download primewords $dl_dir/primewords - fi - - if [ ! -f data/manifests/.primewords.done ]; then - mkdir -p data/manifests - lhotse prepare primewords $dl_dir/primewords data/manifests/primewords - touch data/manifests/.primewords.done - fi - - if [ ! -f data/fbank/.primewords.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_primewords.py --speed-perturb true - touch data/fbank/.primewords.done - fi -fi - -log "Dataset: MagicData" -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare MagicData" - if [ ! -f $dl_dir/magicdata/train_set.tar.gz ]; then - log "Downloading MagicData" - lhotse download magicdata $dl_dir/magicdata - fi - - if [ ! -f data/manifests/.magicdata.done ]; then - mkdir -p data/manifests - lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata - touch data/manifests/.magicdata.done - fi - - if [ ! -f data/fbank/.magicdata.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_magicdata.py --speed-perturb true - touch data/fbank/.magicdata.done - fi -fi - -log "Dataset: aidatatang_200zh" -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare aidatatang_200zh" - if [ -e ../../aidatatang_200zh/ASR/data/fbank/.aidatatang_200zh.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_train) . - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_dev) . - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_test) . - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_train.jsonl.gz) . - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_dev.jsonl.gz) . - ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_test.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../aidatatang_200zh/ASR/prepare.sh --stage 4 --stop-stage 4" - exit 1 - fi -fi - -log "Dataset: Ali-Meeting" -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare Ali-Meeting" - if [ -e ../../alimeeting/ASR/data/fbank/.fbank.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_train) . - ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_eval) . 
- ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_test) . - ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_train.jsonl.gz) . - ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_eval.jsonl.gz) . - ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_test.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../alimeeting/ASR/prepare.sh --stage 5 --stop-stage 5" - exit 1 - fi -fi - -log "Dataset: WenetSpeech" -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare WenetSpeech" - if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then - cd data/fbank - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . - - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_${num_splits}) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech - cd ../.. - else - log "Abort! Please run ../../wenetspeech/ASR/prepare.sh" - exit 1 - fi - - if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then - cd data - cp -r ../../../../wenetspeech/ASR/data/lang_char . - cd .. - else - log "Abort! Please run ../../wenetspeech/ASR/prepare.sh" - exit 1 - fi -fi - -log "Dataset: KeSpeech" -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare KeSpeech" - if [ ! -d $dl_dir/KeSpeech ]; then - log "Abort! Please download KeSpeech first." - log "KeSpeech download link: https://github.com/KeSpeech/KeSpeech" - exit 1 - fi - - if [ ! -f data/manifests/.kespeech.done ]; then - mkdir -p data/manifests - lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech - touch data/manifests/.kespeech.done - fi - - if [ ! -f data/fbank/.kespeech.done ]; then - mkdir -p data/fbank - - log "Preprocess KeSpeech manifest" - if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then - python3 ./local/preprocess_kespeech.py - touch data/fbank/.kespeech_preprocess_complete - fi - - if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then - log "Spliting KeSpeech train_phase1" - lhotse split ${num_splits} \ - data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \ - data/fbank/kespeech/train_phase1_split_${num_splits} - touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done - fi - - if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then - log "Spliting KeSpeech train_phase2" - lhotse split ${num_splits} \ - data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \ - data/fbank/kespeech/train_phase2_split_${num_splits} - touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done - fi - - log "Compute KeSpeech fbank for train_phase1" - ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1 - - log "Compute KeSpeech fbank for train_phase2" - ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase2 - - log "Compute KeSpeech fbank for test/dev" - ./local/compute_fbank_kespeech_dev_test.py - - if [ ! 
-f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then - pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") - lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz - fi - if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then - pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz") - lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz - fi - - touch data/fbank/.kespeech.done - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then - log "Stage 120: Prepare KeSpeech for whisper" - if [ ! -d $dl_dir/KeSpeech ]; then - log "Abort! Please download KeSpeech first." - log "KeSpeech download link: https://github.com/KeSpeech/KeSpeech" - exit 1 - fi - - if [ ! -f data/manifests/.kespeech.done ]; then - mkdir -p data/manifests - lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech - touch data/manifests/.kespeech.done - fi - - if [ ! -f data/fbank/.kespeech.done ]; then - mkdir -p data/fbank - - log "Preprocess KeSpeech manifest" - if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then - python3 ./local/preprocess_kespeech.py --speed-perturb true - touch data/fbank/.kespeech_preprocess_complete - fi - - if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then - log "Spliting KeSpeech train_phase1" - lhotse split ${num_splits} \ - data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \ - data/fbank/kespeech/train_phase1_split_${num_splits} - touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done - fi - - if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then - log "Spliting KeSpeech train_phase2" - lhotse split ${num_splits} \ - data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \ - data/fbank/kespeech/train_phase2_split_${num_splits} - touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done - fi - - log "Compute KeSpeech fbank for train_phase1" - ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - - log "Compute KeSpeech fbank for train_phase2" - ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - - log "Compute KeSpeech fbank for test/dev" - # ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - - if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then - pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") - lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz - fi - if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then - pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz") - lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz - fi - touch data/fbank/.kespeech.done - fi -fi - -if [ $stage -le 121 ] && [ $stop_stage -ge 121 ]; then - log "Stage 121: Prepare MagicData, Primewords, ST-CMDS, THCHS-30 for whisper" - - if [ ! 
-f data/manifests/.magicdata.done ]; then - mkdir -p data/manifests - lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata - touch data/manifests/.magicdata.done - fi - - if [ ! -f data/manifests/.primewords.done ]; then - mkdir -p data/manifests - lhotse prepare primewords $dl_dir/primewords data/manifests/primewords - touch data/manifests/.primewords.done - fi - if [ ! -f data/manifests/.stcmds.done ]; then - mkdir -p data/manifests - lhotse prepare stcmds $dl_dir/stcmds data/manifests/stcmds - touch data/manifests/.stcmds.done - fi - - if [ ! -f data/manifests/.thchs30.done ]; then - mkdir -p data/manifests - lhotse prepare thchs-30 $dl_dir/thchs30 data/manifests/thchs30 - touch data/manifests/.thchs30.done - fi - - if [ ! -f data/fbank/.thchs30.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_thchs30.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.thchs30.done - fi - - if [ ! -f data/fbank/.stcmds.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_stcmds.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.stcmds.done - fi - if [ ! -f data/fbank/.magicdata.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_magicdata.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.magicdata.done - fi - - if [ ! -f data/fbank/.primewords.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_primewords.py --speed-perturb true --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.primewords.done - fi - -fi - - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: BPE model training (note that we use transcripts of wenetspeech only for BPE training)" - ./local/prepare_for_bpe_model.py --lang-dir ./data/lang_char --text ./data/lang_char/text - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - mkdir -p $lang_dir - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --transcript ./data/lang_char/transcript_chars.txt \ - --vocab-size $vocab_size - - ./local/bpe_model_to_tokens.py $lang_dir/bpe.model > $lang_dir/tokens.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - cp data/lang_char/words.txt $lang_dir - - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)" - - if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then - cd data - ln -s ../../../../wenetspeech/ASR/data/lm . - cd .. - else - log "Abort! 
Please run ../../wenetspeech/ASR/prepare.sh" - exit 1 - fi -fi - -if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then - log "Stage 15: Compile LG" - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - python ./local/compile_lg.py --lang-dir $lang_dir - done -fi diff --git a/egs/multi_zh-hans/ASR/shared b/egs/multi_zh-hans/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/multi_zh-hans/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py b/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py deleted file mode 120000 index 3c8b7f2d4..000000000 --- a/egs/multi_zh-hans/ASR/whisper/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py deleted file mode 100755 index f758f546c..000000000 --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ /dev/null @@ -1,567 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -# Command for decoding using fine-tuned models: -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 - -# Command for decoding using pretrained models (before fine-tuning): - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch -1 --avg 1 \ - --remove-whisper-encoder-input-length-restriction False \ - --beam-size 10 --max-duration 50 - -""" - -import argparse -import logging -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -import whisper -from asr_datamodule import AsrDataModule -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from tn.chinese.normalizer import Normalizer -from whisper.normalizers import BasicTextNormalizer -from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from zhconv import convert - -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def average_checkpoints( - filenames: List[Path], device: torch.device = torch.device("cpu") -) -> dict: - """Average a list of checkpoints. 
- The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. - - Args: - filenames: - Filenames of the checkpoints to be averaged. We assume all - checkpoints are saved by :func:`save_checkpoint`. - device: - Move checkpoints to this device before averaging. - Returns: - Return a dict (i.e., state_dict) which is the average of all - model state dicts contained in the checkpoints. - """ - n = len(filenames) - - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] - else: - avg = torch.load(filenames[0], map_location=device) - - # Identify shared parameters. Two parameters are said to be shared - # if they have the same data_ptr - uniqued: Dict[int, str] = dict() - - for k, v in avg.items(): - v_data_ptr = v.data_ptr() - if v_data_ptr in uniqued: - continue - uniqued[v_data_ptr] = k - - uniqued_names = list(uniqued.values()) - - for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] - else: - state_dict = torch.load(filenames[i], map_location=device) - for k in uniqued_names: - avg[k] += state_dict[k] - - for k in uniqued_names: - if avg[k].is_floating_point(): - avg[k] /= n - else: - avg[k] //= n - - return avg - - -def remove_punctuation(text: str or List[str]): - """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py - - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings without any punctuation. - """ - punctuation = "!,.;:?、！，。；：？《》 " - if isinstance(text, str): - text = re.sub(r"[{}]+".format(punctuation), "", text).strip() - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = re.sub(r"[{}]+".format(punctuation), "", t).strip() - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type {type(text)}") - - -def to_simple(text: str or List[str]): - """Convert traditional Chinese to simplified Chinese. - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings converted to simplified Chinese. - """ - if isinstance(text, str): - text = convert(text, "zh-cn") - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = convert(t, "zh-cn") - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type{type(text)}") - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=-1, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="beam-search", - help="""Decoding method.
- Supported values are: - - beam-search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=1, - help="beam size for beam search decoding", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--remove-whisper-encoder-input-length-restriction", - type=str2bool, - default=True, - help="replace whisper encoder forward method to remove input length restriction", - ) - - parser.add_argument( - "--use-distill-whisper", - type=str2bool, - default=False, - help="Whether to use architecture of distill whisper.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: "beam-search" - - value: A list of lists. Each sublist is a list of token IDs. - Args: - params: - It is returned by :func:`get_params`. - model: - The neural model. - batch: - It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. - Returns: - Return a dict, whose key may be "beam-search". - """ - dtype = torch.float16 - device = torch.device("cuda") - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device, dtype=dtype).transpose(1, 2) - if not params.remove_whisper_encoder_input_length_restriction: - T = 3000 - if feature.shape[2] < T: - feature = torch.cat( - [ - feature, - torch.zeros( - feature.shape[0], feature.shape[1], T - feature.shape[2] - ).to(device, dtype=dtype), - ], - 2, - ) - - supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) - results = model.decode(feature, params.decoding_options) - hyps = [result.text for result in results] - - hyps = remove_punctuation(hyps) - hyps = to_simple(hyps) - hyps = [params.normalizer.normalize(hyp) for hyp in hyps] - print(hyps) - return {"beam-search": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - The dataloader. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a dict, whose key may be "beam-search". - """ - - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("<sil>", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("<space>", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("Ａ", "A") - text = text.replace("ａ", "A") - text = text.replace("ｂ", "B") - text = text.replace("ｃ", "C") - text = text.replace("ｋ", "K") - text = text.replace("ｔ", "T") - text = text.replace("，", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("？", "") - return text - - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - batch=batch, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # we compute CER for aishell dataset.
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" - ) - - options = whisper.DecodingOptions( - task="transcribe", - language="zh", - without_timestamps=True, - beam_size=params.beam_size, - ) - params.decoding_options = options - params.cleaner = BasicTextNormalizer() - params.normalizer = Normalizer() - - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda") - - logging.info(f"device: {device}") - - if params.remove_whisper_encoder_input_length_restriction: - replace_whisper_encoder_forward() - if params.use_distill_whisper: - replace_whisper_decoder_forward() - model = whisper.load_model(params.model_name, "cpu") - if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - # deepspeed converted checkpoint only contains model state_dict - filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" - for epoch in range(start, params.epoch + 1) - ] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) - else: - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - model.load_state_dict(checkpoint, strict=True) - else: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
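The branch above averages several epoch checkpoints when the DeepSpeed-converted files contain only a bare state_dict. A simplified sketch of the underlying idea, an element-wise mean of the saved tensors (the real icefall helper average_checkpoints handles dtypes and bookkeeping differently):

import torch

def average_state_dicts(filenames):
    # Simplified sketch: element-wise mean of model weights across checkpoint files.
    avg = None
    for name in filenames:
        state = torch.load(name, map_location="cpu")
        state = state.get("model", state)  # accept a bare state_dict or {"model": ...}
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}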
- args.return_cuts = True - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - def remove_long_utt(c: Cut): - # Keep only utterances with duration in 30 seconds - # - if c.duration > 30.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_dls = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json b/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json deleted file mode 120000 index af7162d6c..000000000 --- a/egs/multi_zh-hans/ASR/whisper/ds_config_zero1.json +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/ds_config_zero1.json \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/label_smoothing.py b/egs/multi_zh-hans/ASR/whisper/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/multi_zh-hans/ASR/whisper/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py b/egs/multi_zh-hans/ASR/whisper/multi_dataset.py deleted file mode 120000 index d2e14a1ad..000000000 --- a/egs/multi_zh-hans/ASR/whisper/multi_dataset.py +++ /dev/null @@ -1 +0,0 @@ -../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/optim.py b/egs/multi_zh-hans/ASR/whisper/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/multi_zh-hans/ASR/whisper/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/requirements.txt b/egs/multi_zh-hans/ASR/whisper/requirements.txt deleted file mode 120000 index 744bf8bb6..000000000 --- a/egs/multi_zh-hans/ASR/whisper/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py deleted file mode 100755 index fe2d950c1..000000000 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ /dev/null @@ -1,1029 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
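Both the decode script above and the training script that follows call replace_whisper_encoder_forward() before loading the model. The patched encoder forward itself lives in a symlinked file whose contents are not reproduced in this diff; as a rough, hedged sketch, the usual trick is to slice Whisper's fixed positional embedding to the actual number of frames instead of asserting a full 30-second (3000-frame) input:

# Hedged sketch only, not the recipe's actual patch. Attribute names (conv1, conv2,
# blocks, ln_post, positional_embedding) follow whisper.model.AudioEncoder.
import torch.nn.functional as F
from torch import Tensor

def encoder_forward(self, x: Tensor) -> Tensor:
    x = F.gelu(self.conv1(x))           # (N, n_state, T)
    x = F.gelu(self.conv2(x))           # stride-2 conv halves T
    x = x.permute(0, 2, 1)              # (N, T // 2, n_state)
    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)  # slice, no length assert
    for block in self.blocks:
        x = block(x)
    return self.ln_post(x)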
-""" -Usage: - -#fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json - -# fine-tuning with ddp -torchrun --nproc_per_node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_medium \ - --base-lr 1e-5 \ - --model-name medium -""" - -import argparse -import copy -import logging -import os -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import deepspeed -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import whisper -from asr_datamodule import AsrDataModule -from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict -from label_smoothing import LabelSmoothingLoss -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from multi_dataset import MultiDataset -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.functional import pad as pad_tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import update_averaged_model -from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=10, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--pretrained-model-path", - type=str, - default=None, - help="""The path to the pretrained model if it is not None. Training will - start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=1e-5, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-distill-whisper", - type=str2bool, - default=False, - help="Whether to use architecture of distill whisper.", - ) - - parser = deepspeed.add_config_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - frame_shift_ms: The frame shift in milliseconds. - - allowed_excess_duration_ratio: The allowed excess duration ratio. - - best_train_loss: The best training loss so far. - - best_valid_loss: The best validation loss so far. - - best_train_epoch: The epoch where the best training loss is achieved. - - best_valid_epoch: The epoch where the best validation loss is achieved. - - batch_idx_train: The batch index of the current batch. - - log_interval: Log training stats every `log_interval` batches. 
- - reset_interval: Reset the stats every `reset_interval` batches. - - valid_interval: Run validation every `valid_interval` batches. - - env_info: The environment information. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "subsampling_factor": 2, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 10000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute the loss for the given batch. - Args: - params: - It is returned by :func:`get_params`. - tokenizer: - The tokenizer used to encode the text. - model: - The model for training. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - Whether it is training. - Returns: - Return a tuple of two elements. The first element is the loss tensor. - """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - - def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: - padding_size = max(tensor.shape[0] for tensor in tensors) - dims = len(tensors[0].shape) - padded_tensors = [] - for tensor in tensors: - padding = [0] * 2 * dims - padding[-1] = padding_size - tensor.shape[0] - padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) - return torch.stack([tensor for tensor in padded_tensors], dim=0) - - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("<sil>", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("<space>", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("Ａ", "A") - text = text.replace("ａ", "A") - text = text.replace("ｂ", "B") - text = text.replace("ｃ", "C") - text = text.replace("ｋ", "K") - text = text.replace("ｔ", "T") - text = text.replace("，", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("？", "") - return text - - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - - assert feature.ndim == 3 - feature = feature.to(device) - feature = feature.transpose(1, 2) # (N, C, T) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - - texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [normalize_text_alimeeting(text) for text in texts] - - text_tokens_list = [ - list(tokenizer.sot_sequence_including_notimestamps) - + tokenizer.encode(text) - + [tokenizer.eot] - for text in texts - ] - # convert it to torch tensor - text_tokens_list = [ - torch.LongTensor(text_tokens) for text_tokens in text_tokens_list - ] - - # 50256 is the index of <pad> for all whisper models - prev_outputs_tokens = _batch_tensors( - [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 - ) - target_tokens = _batch_tensors( - [tokens[1:] for tokens in text_tokens_list], pad_value=50256 - ) - target_lengths = torch.LongTensor( - [tokens.shape[0] - 1 for tokens in text_tokens_list] - ) - - decoder_criterion = LabelSmoothingLoss( - ignore_index=50256, label_smoothing=0.1, reduction="sum" - ) - - # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|> - ignore_prefix_size = 3 - with torch.set_grad_enabled(is_training): - encoder_out = model.encoder(feature) - text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) - text_logits = text_logits[:, ignore_prefix_size:, :] - target_tokens = target_tokens[:, ignore_prefix_size:] - loss = decoder_criterion(text_logits, target_tokens.to(device)) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss.
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - tokenizer=tokenizer, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - if params.deepspeed: - model.save_checkpoint( - save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", - client_state={}, - ) - if rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", - tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", - ) - os.system( - f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" - ) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if params.deepspeed: - # deepspeed's backward() is different from torch's backward() - # in that it does not accept a loss tensor as input. - # It computes the loss internally. - model.backward(loss) - model.step() - else: - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - and not params.deepspeed - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - if batch_idx % params.log_interval == 0: - try: - cur_lr = scheduler.get_last_lr()[0] - except: # noqa - cur_lr = 0.0 - cur_grad_scale = ( - scaler._scale.item() - if (params.use_fp16 and not params.deepspeed) - else 1.0 - ) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if (params.use_fp16 and not params.deepspeed) - else "" - ) - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info(params) - - logging.info("About to create model") - - replace_whisper_encoder_forward() - if params.use_distill_whisper: - replace_whisper_decoder_forward() - model = whisper.load_model(params.model_name, "cpu") - del model.alignment_heads - - if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") - if "model" not in checkpoint: - model.load_state_dict(checkpoint, strict=True) - else: - load_checkpoint(params.pretrained_model_path, model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - tokenizer = whisper.tokenizer.get_tokenizer( - model.is_multilingual, - num_languages=model.num_languages, - language="zh", - task="transcribe", - ) - - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - else: - device = torch.device("cpu") - logging.info(f"Device: {device}") - model.to(device) - - optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" 
in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if world_size > 1: - if params.deepspeed: - logging.info("Using DeepSpeed") - model, optimizer, _, scheduler = deepspeed.initialize( - args=params, model=model, model_parameters=model.parameters() - ) - else: - logging.info("Using DDP") - setup_dist(use_ddp_launch=True) - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = multi_dataset.train_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = data_module.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = multi_dataset.dev_cuts() - valid_dl = data_module.valid_dataloaders(valid_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - if not params.deepspeed: - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - tokenizer=tokenizer, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if params.deepspeed: - model.save_checkpoint( - save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", - client_state={}, - ) - if rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", - ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") - else: - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, 
- sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1 and not params.deepspeed: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = get_world_size() - rank = get_rank() - - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - run(rank=rank, world_size=world_size, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py b/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py deleted file mode 100644 index c013426d4..000000000 --- a/egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Dict, Iterable, Optional - -import numpy as np -import torch -import torch.nn.functional as F -import whisper -from torch import Tensor, nn -from whisper.model import LayerNorm, ResidualAttentionBlock - - -def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): - """ - x : torch.LongTensor, shape = (batch_size, <= n_ctx) - the text tokens - xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) - the encoded audio features to be attended on - """ - offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 - x = ( - self.token_embedding(x) - + self.positional_embedding[offset : offset + x.shape[-1]] - ) - x = x + self.positional_embedding[offset : offset + x.shape[1]] - x = x.to(xa.dtype) - - # for block in self.blocks: - # x = block(x, xa, mask=self.mask, kv_cache=kv_cache) - # use architecture from the distill whisper model - # see https://github.com/huggingface/distil-whisper - x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache) - x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache) - - x = self.ln(x) - logits = ( - x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) - ).float() - - return logits - - -def replace_whisper_decoder_forward(): - """ - This function monkey patches the forward method of the whisper encoder. - To be called before the model is loaded, it changes whisper to process audio with any length < 30s. 
- """ - whisper.model.TextDecoder.forward = forward diff --git a/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py deleted file mode 120000 index 2a7808921..000000000 --- a/egs/multi_zh-hans/ASR/whisper/whisper_encoder_forward_monkey_patch.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index 341579acb..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,390 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=300.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. 
mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl diff --git a/egs/multi_zh-hans/ASR/zipformer/beam_search.py b/egs/multi_zh-hans/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py deleted file mode 100755 index 8d4a81fb0..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1,623 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Liyong Guo, -# Quandong Wang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
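A typical way to drive the AsrDataModule deleted above (the training script earlier in this patch does essentially the same): register its CLI options, parse them, and request dataloaders. The manifest paths and option values below are placeholders, not paths from the recipe:

import argparse
from lhotse import load_manifest_lazy
from asr_datamodule import AsrDataModule  # the class shown above

parser = argparse.ArgumentParser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args(
    ["--manifest-dir", "data/fbank", "--max-duration", "200", "--enable-musan", "false"]
)

data_module = AsrDataModule(args)
train_cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")  # placeholder manifest
train_dl = data_module.train_dataloaders(train_cuts)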
-""" -Usage: - -(1) ctc-decoding -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method ctc-decoding - -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import get_lattice, one_best_decoding -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="ctc-decoding", - help="""Decoding method. - Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. 
- Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. - - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. - - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. - - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.decoding_method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - word_table: - It is the word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - H=H, - bpe_model=bpe_model, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) - hyp_text = "".join(hyp_words) - this_batch.append((cut_id, ref_text, hyp_text)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - assert params.decoding_method in ("ctc-decoding",) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
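For reference, the wer-summary file written by save_results() above is a header line followed by one tab-separated settings/WER pair per decoding configuration, sorted best-first. A minimal sketch, with a made-up WER value used purely for illustration:

test_set_wers = {"ctc-decoding": 7.85}  # made-up number, for illustration only
rows = sorted(test_set_wers.items(), key=lambda x: x[1])
with open("wer-summary-example.txt", "w") as f:
    print("settings\tWER", file=f)
    for key, val in rows:
        print("{}\t{}".format(key, val), file=f)
# resulting file content (tab-separated):
# settings    WER
# ctc-decoding    7.85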
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - params.vocab_size = num_classes - # and are defined in local/train_bpe_model.py - params.blank_id = 0 - - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=True, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - - G = None - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()} - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py deleted file mode 100755 index a1d018cd2..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ /dev/null @@ -1,843 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
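A short worked example of the remove_short_utt() filter that the decoding scripts here use to drop cuts too short for the encoder: under its arithmetic, cuts with fewer than 9 feature frames yield a non-positive length and are excluded (the sample frame counts below are illustrative):

def encoder_frames(num_frames: int) -> int:
    # same arithmetic as remove_short_utt()
    return ((num_frames - 7) // 2 + 1) // 2

for n in (7, 8, 9, 23, 100):  # illustrative frame counts
    print(n, "->", encoder_frames(n))
# 7 -> 0 (excluded), 8 -> 0 (excluded), 9 -> 1, 23 -> 4, 100 -> 23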
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key 
+= f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) - hyp_text = "".join(hyp_words) - this_batch.append((cut_id, ref_text, hyp_text)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - params.suffix += f"-blank-penalty-{params.blank_penalty}" - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < 
params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()} - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/decoder.py b/egs/multi_zh-hans/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py deleted file mode 120000 index 652346001..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export.py b/egs/multi_zh-hans/ASR/zipformer/export.py deleted file mode 100755 index 723288191..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export.py +++ /dev/null @@ -1,541 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_2000/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_2000/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_2000/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. 
You can get it at - -- non-streaming model: -https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -import re -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_2000/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. 
For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, 
inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py deleted file mode 100755 index 68111fad7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -./zipformer/generate_averaged_model.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(2) use the checkpoint exp_dir/checkpoint-iter.pt -./zipformer/generate_averaged_model.py \ - --iter 22000 \ - --avg 5 \ - --exp-dir ./zipformer/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table[""] - params.unk_id = symbol_table[""] - params.vocab_size = len(symbol_table) - - print("About to create model") - model = get_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- 
a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 120000 index 9a8da5844..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/joiner.py b/egs/multi_zh-hans/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/model.py b/egs/multi_zh-hans/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py deleted file mode 120000 index d2e14a1ad..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ /dev/null @@ -1 +0,0 @@ -../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py deleted file mode 120000 index 0573b88c5..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py deleted file mode 120000 index d623a8462..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at 
end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/optim.py b/egs/multi_zh-hans/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py deleted file mode 100755 index c15db11f7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 23 \ - --avg 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_2000/tokens.txt \ - --epoch 23 \ - --avg 1 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_2000/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_2000/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_2000/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_2000/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_2000/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_2000/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp/epoch-xx.pt`. 
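The wave files passed to the commands above must be 16 kHz (see the `--sample-rate` and `sound_files` help further down). A minimal, purely illustrative torchaudio sketch for converting an arbitrary recording beforehand (file names are hypothetical):

```
import torchaudio

# Hypothetical paths; downmix to mono and resample to the 16 kHz these recipes expect.
wave, sr = torchaudio.load("input.flac")   # shape: (channels, samples)
wave = wave.mean(dim=0, keepdim=True)      # keep a single channel
if sr != 16000:
    wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("foo.wav", wave, sample_rate=16000)
```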
- -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.utils import make_pad_mask - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling.py b/egs/multi_zh-hans/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py deleted file mode 120000 index 13fd02a78..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/subsampling.py b/egs/multi_zh-hans/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py deleted file mode 100755 index 3dbfc48eb..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ /dev/null @@ -1,1425 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from multi_dataset import MultiDataset -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. 
- - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. - See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
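To make the frame-count bookkeeping in the comment above concrete, a tiny illustrative calculation (approximate; the exact edge handling is defined by Conv2dSubsampling and Zipformer2 themselves):

```
# Illustrative: a 10-second utterance with a 10 ms frame shift.
num_frames = 1000                    # input frames at 100 Hz
after_embed = (num_frames - 7) // 2  # after Conv2dSubsampling, roughly 50 Hz
after_encoder = after_embed // 2     # after the final 2x output downsampling, roughly 25 Hz
print(after_embed, after_encoder)    # 496 248
```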
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. 
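To illustrate the resume logic documented above, a short sketch of the filename selection only, using hypothetical argument values that mirror the branches in the function body below:

```
from pathlib import Path

# Hypothetical: resume epoch-based training with --start-epoch 11 and no --start-batch.
exp_dir = Path("zipformer/exp")
start_batch, start_epoch = 0, 11

if start_batch > 0:
    filename = exp_dir / f"checkpoint-{start_batch}.pt"
elif start_epoch > 1:
    filename = exp_dir / f"epoch-{start_epoch - 1}.pt"
else:
    filename = None  # fresh training, nothing to load
print(filename)  # zipformer/exp/epoch-10.pt
```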
- """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [normalize_text_alimeeting(text) for text in texts] - - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. 
- scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
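A small sketch of the scale-growth policy described in the comment above; it restates the checks performed in the lines that follow and is illustrative only:

```
# Illustrative: decide whether to double a small AMP grad scale.
def should_double_grad_scale(cur_grad_scale: float, batch_idx: int) -> bool:
    if batch_idx % 100 != 0:
        return False
    return cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0)

print(should_double_grad_scale(4.0, 100))   # True
print(should_double_grad_scale(16.0, 200))  # False
print(should_double_grad_scale(16.0, 400))  # True
```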
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - train_cuts = multi_dataset.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. 
" - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = data_module.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = multi_dataset.dev_cuts() - valid_dl = data_module.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/zipformer.py b/egs/multi_zh-hans/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md deleted file mode 100644 index 29341571d..000000000 --- a/egs/multi_zh_en/ASR/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Introduction - -This recipe includes scripts for training Zipformer model using both English and Chinese datasets. - -# Included Training Sets - -1. LibriSpeech (English) -2. AiShell-2 (Chinese) -3. TAL-CSASR (Code-Switching, Chinese and English) - -|Datset| Number of hours| URL| -|---|---:|---| -|**TOTAL**|2,547|---| -|LibriSpeech|960|https://www.openslr.org/12/| -|AiShell-2|1,000|http://www.aishelltech.com/aishell_2| -|TAL-CSASR|587|https://ai.100tal.com/openData/voice| - - - diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md deleted file mode 100644 index 3562d6ac3..000000000 --- a/egs/multi_zh_en/ASR/RESULTS.md +++ /dev/null @@ -1,44 +0,0 @@ -## Results - -### Zh-En datasets bpe-based training results (Non-streaming) on Zipformer model - -This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1265) in icefall. - -#### Non-streaming (Byte-Level BPE vocab_size=2000) - -Best results (num of params : ~69M): - -The training command: - -``` -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 35 \ - --use-fp16 1 \ - --max-duration 1000 \ - --num-workers 8 -``` - -The decoding command: - -``` -for method in greedy_search modified_beam_search fast_beam_search; do - ./zipformer/decode.py \ - --epoch 34 \ - --avg 19 \ - --decoding-method $method -done -``` - -Word Error Rates (WERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model (# tokens is 2000). 
- -| Datasets | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech | -|----------------------|-----------|-----------|-----------|-----------|-------------|-------------| -| Zipformer WER (%) | dev | test | dev | test | test-clean | test-other | -| greedy_search | 6.65 | 6.69 | 6.57 | 7.03 | 2.43 | 5.70 | -| modified_beam_search | 6.46 | 6.51 | 6.18 | 6.60 | 2.41 | 5.57 | -| fast_beam_search | 6.57 | 6.68 | 6.40 | 6.74 | 2.40 | 5.56 | - -Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22, which is trained on LibriSpeech 960-hour training set (with speed perturbation), TAL-CSASR training set (with speed perturbation) and AiShell-2 (w/o speed perturbation). - - diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/multi_zh_en/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py deleted file mode 120000 index 42743b544..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_char.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py deleted file mode 100755 index 00514e6bb..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script tokenizes the training transcript by CJK characters -# and saves the result to transcript_chars.txt, which is used -# to train the BPE model later. - -import argparse -from pathlib import Path - -from tqdm.auto import tqdm - -from icefall.utils import tokenize_by_CJK_char - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Output directory. - The generated transcript_chars.txt is saved to this directory. - """, - ) - - parser.add_argument( - "--text", - type=str, - help="Training transcript.", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - text = Path(args.text) - - assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
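For intuition, tokenize_by_CJK_char separates each CJK character with spaces while leaving non-CJK word pieces intact; the exact output below is an assumption, shown only to indicate the general effect:

```
from icefall.utils import tokenize_by_CJK_char

# Hypothetical mixed Chinese/English line from a training transcript.
line = "今天天气很好 it is sunny today"
print(tokenize_by_CJK_char(line))
# roughly: 今 天 天 气 很 好 it is sunny today
```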
- - transcript_path = lang_dir / "transcript_chars.txt" - - with open(text, "r", encoding="utf-8") as fin: - with open(transcript_path, "w+", encoding="utf-8") as fout: - for line in tqdm(fin): - fout.write(tokenize_by_CJK_char(line) + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py deleted file mode 120000 index 9a0b44642..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_lang_bbpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py deleted file mode 120000 index ef2b4eaf3..000000000 --- a/egs/multi_zh_en/ASR/local/prepare_words.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py deleted file mode 120000 index 7d68a39c3..000000000 --- a/egs/multi_zh_en/ASR/local/text2segments.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py deleted file mode 120000 index ce5cfd537..000000000 --- a/egs/multi_zh_en/ASR/local/text2token.py +++ /dev/null @@ -1 +0,0 @@ -../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py deleted file mode 120000 index 7fb4a9f9d..000000000 --- a/egs/multi_zh_en/ASR/local/train_bbpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/train_bbpe_model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh deleted file mode 100755 index a1530be29..000000000 --- a/egs/multi_zh_en/ASR/prepare.sh +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -vocab_sizes=( - 2000 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -log "Dataset: musan" -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Soft link fbank of musan" - mkdir -p data/fbank - if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . - cd ../.. - else - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" - exit 1 - fi -fi - -log "Dataset: LibriSpeech" -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Soft link fbank of LibriSpeech" - mkdir -p data/fbank - if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . - ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . - cd ../.. - else - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -log "Dataset: AiShell-2" -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Soft link fbank of AiShell-2" - mkdir -p data/fbank - if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then - cd data/fbank - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) . - ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) . - cd ../.. - else - log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare Byte BPE based lang" - mkdir -p data/fbank - if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then - log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" - exit 1 - fi - - if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then - log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6" - exit 1 - fi - - cd data/ - if [ ! -d ./lang_char ]; then - ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) . - fi - if [ ! -d ./lang_bpe_500 ]; then - ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . - fi - cd ../ - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - mkdir -p $lang_dir - - cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ - > $lang_dir/text - - if [ ! -f $lang_dir/transcript_chars.txt ]; then - ./local/prepare_for_bpe_model.py \ - --lang-dir ./$lang_dir \ - --text $lang_dir/text - fi - - if [ ! -f $lang_dir/text_words_segmentation ]; then - python3 ./local/text2segments.py \ - --input-file ./data/lang_char/text \ - --output-file $lang_dir/text_words_segmentation - - cat ./data/lang_bpe_500/transcript_words.txt \ - >> $lang_dir/text_words_segmentation - fi - - cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt - - if [ ! -f $lang_dir/words.txt ]; then - python3 ./local/prepare_words.py \ - --input-file $lang_dir/words_no_ids.txt \ - --output-file $lang_dir/words.txt - fi - - if [ ! -f $lang_dir/bbpe.model ]; then - ./local/train_bbpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - - if [ ! 
-f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bbpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bbpe.model - fi - done -fi - diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/multi_zh_en/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py deleted file mode 100644 index 489b38e65..000000000 --- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1,387 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=300.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. 
mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/multi_zh_en/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py deleted file mode 100755 index e21e8f052..000000000 --- a/egs/multi_zh_en/ASR/zipformer/decode.py +++ /dev/null @@ -1,851 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
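The `AsrDataModule` deleted above only builds lhotse dataloaders; nothing in this diff shows how a script consumes it. The sketch below is a hedged usage example, not code from the recipe: the manifest path is hypothetical, and it assumes `asr_datamodule.py` is importable from the recipe directory and that fbank manifests (including `musan_cuts.jsonl.gz`) have already been prepared.

```python
import argparse

from asr_datamodule import AsrDataModule  # the module deleted above
from lhotse import load_manifest_lazy

parser = argparse.ArgumentParser()
AsrDataModule.add_arguments(parser)  # registers --max-duration, --num-buckets, ...
args = parser.parse_args([])         # defaults only, for illustration

data_module = AsrDataModule(args)
# Hypothetical manifest produced by an earlier preparation stage.
train_cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
train_dl = data_module.train_dataloaders(train_cuts)

for batch in train_dl:
    print(batch["inputs"].shape)  # (N, T, C) fbank features
    break
```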
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params - -from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_2000/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bbpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-tal-csasr", - type=str2bool, - default=False, - help="Whether to use TAL-CSASR training data.", - ) - - parser.add_argument( - "--use-librispeech", - type=str2bool, - default=False, - help="Whether to use LibriSpeech training data.", - ) - - parser.add_argument( - "--use-aishell2", - type=str2bool, - default=False, - help="Whether to use Aishell-2 training data.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode( - byte_encode(tokenize_by_CJK_char(supervisions["text"])) - ), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += 
f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [tokenize_by_CJK_char(str(text)).split() for text in texts] - # print(texts) - # exit() - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints 
({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh_en/ASR/zipformer/decode_stream.py b/egs/multi_zh_en/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/multi_zh_en/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/multi_zh_en/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/multi_zh_en/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py deleted file mode 100755 index fbd9ce0dd..000000000 --- a/egs/multi_zh_en/ASR/zipformer/export.py +++ /dev/null @@ -1,541 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
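The `remove_short_utt` filter in the deleted `decode.py` above drops cuts whose features would be consumed entirely by the encoder frontend. A quick illustrative check of that subsampling arithmetic (the formula is copied from the filter itself):

```python
def num_encoder_frames(num_frames: int) -> int:
    # Same arithmetic as remove_short_utt: the encoder_embed subsamples
    # T input frames to roughly ((T - 7) // 2 + 1) // 2 output frames.
    return ((num_frames - 7) // 2 + 1) // 2


for t in (8, 9, 11, 100):
    print(t, "->", num_encoder_frames(t))
# 8 -> 0 (too short: excluded from decoding), 9 -> 1, 11 -> 1, 100 -> 23
```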
- -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 20 \ - --avg 1 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bbpe_2000/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bbpe_2000/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bbpe_2000/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. 
You can get it at - -- non-streaming model: -https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -import re -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bbpe_2000/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. 
For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, 
inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py deleted file mode 100755 index 68111fad7..000000000 --- a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -./zipformer/generate_averaged_model.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(2) use the checkpoint exp_dir/checkpoint-iter.pt -./zipformer/generate_averaged_model.py \ - --iter 22000 \ - --avg 5 \ - --exp-dir ./zipformer/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table["<blk>"] - params.unk_id = symbol_table["<unk>"] - params.vocab_size = len(symbol_table) - - print("About to create model") - model = get_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 ---
a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 120000 index 9a8da5844..000000000 --- a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/multi_zh_en/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/multi_zh_en/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py deleted file mode 100644 index 1155a3dcc..000000000 --- a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
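# The usage notes of the deleted generate_averaged_model.py above say the
# averaged checkpoint can be reloaded with torch.load(). A minimal sketch of
# that round trip, assuming the default exp dir and an --epoch 28 --avg 15 run
# (the path here is illustrative):
import torch

ckpt = torch.load("zipformer/exp/epoch-28-avg-15.pt", map_location="cpu")
state_dict = ckpt["model"]  # the averaged weights are saved under the "model" key
print(f"{len(state_dict)} tensors in the averaged checkpoint")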
- - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Dict - -from lhotse import CutSet, load_manifest_lazy - - -class MultiDataset: - def __init__(self, args: argparse.Namespace): - """ - Args: - manifest_dir: - It is expected to contain the following files: - - aishell2_cuts_train.jsonl.gz - """ - self.fbank_dir = Path(args.manifest_dir) - self.use_tal_csasr = args.use_tal_csasr - self.use_librispeech = args.use_librispeech - self.use_aishell2 = args.use_aishell2 - - def train_cuts(self) -> CutSet: - logging.info("About to get multidataset train cuts") - - # AISHELL-2 - if self.use_aishell2: - logging.info("Loading Aishell-2 in lazy mode") - aishell_2_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_train.jsonl.gz" - ) - - # TAL-CSASR - if self.use_tal_csasr: - logging.info("Loading TAL-CSASR in lazy mode") - tal_csasr_cuts = load_manifest_lazy( - self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz" - ) - - # LibriSpeech - if self.use_librispeech: - logging.info("Loading LibriSpeech in lazy mode") - train_clean_100_cuts = self.train_clean_100_cuts() - train_clean_360_cuts = self.train_clean_360_cuts() - train_other_500_cuts = self.train_other_500_cuts() - - if self.use_tal_csasr and self.use_librispeech and self.use_aishell2: - return CutSet.mux( - aishell_2_cuts, - train_clean_100_cuts, - train_clean_360_cuts, - train_other_500_cuts, - tal_csasr_cuts, - weights=[ - len(aishell_2_cuts), - len(train_clean_100_cuts), - len(train_clean_360_cuts), - len(train_other_500_cuts), - len(tal_csasr_cuts), - ], - ) - elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2: - return CutSet.mux( - aishell_2_cuts, - train_clean_100_cuts, - train_clean_360_cuts, - train_other_500_cuts, - weights=[ - len(aishell_2_cuts), - len(train_clean_100_cuts), - len(train_clean_360_cuts), - len(train_other_500_cuts), - ], - ) - elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2: - return CutSet.mux( - aishell_2_cuts, - tal_csasr_cuts, - weights=[ - len(aishell_2_cuts), - len(tal_csasr_cuts), - ], - ) - elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2: - return CutSet.mux( - train_clean_100_cuts, - train_clean_360_cuts, - train_other_500_cuts, - tal_csasr_cuts, - weights=[ - len(train_clean_100_cuts), - len(train_clean_360_cuts), - len(train_other_500_cuts), - len(tal_csasr_cuts), - ], - ) - else: - raise NotImplementedError( - f"""Not implemented for - use_aishell2: {self.use_aishell2} - use_librispeech: {self.use_librispeech} - use_tal_csasr: {self.use_tal_csasr}""" - ) - - def dev_cuts(self) -> CutSet: - logging.info("About to get multidataset dev cuts") - - # AISHELL-2 - logging.info("Loading Aishell-2 DEV set in lazy mode") - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # LibriSpeech - dev_clean_cuts = self.dev_clean_cuts() - dev_other_cuts = self.dev_other_cuts() - - logging.info("Loading TAL-CSASR set in lazy mode") - tal_csasr_dev_cuts = load_manifest_lazy( - self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" - ) - - return CutSet.mux( - aishell2_dev_cuts, - dev_clean_cuts, - dev_other_cuts, - tal_csasr_dev_cuts, - weights=[ - len(aishell2_dev_cuts), - len(dev_clean_cuts), - len(dev_other_cuts), - len(tal_csasr_dev_cuts), - ], - ) - - def test_cuts(self) -> Dict[str, CutSet]: - logging.info("About to get multidataset test cuts") - - # AISHELL-2 - if self.use_aishell2: - logging.info("Loading Aishell-2 set in lazy 
mode") - aishell2_test_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_test.jsonl.gz" - ) - aishell2_dev_cuts = load_manifest_lazy( - self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" - ) - - # LibriSpeech - if self.use_librispeech: - test_clean_cuts = self.test_clean_cuts() - test_other_cuts = self.test_other_cuts() - - logging.info("Loading TAL-CSASR set in lazy mode") - tal_csasr_test_cuts = load_manifest_lazy( - self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz" - ) - tal_csasr_dev_cuts = load_manifest_lazy( - self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" - ) - - test_cuts = { - "tal_csasr_test": tal_csasr_test_cuts, - "tal_csasr_dev": tal_csasr_dev_cuts, - } - - if self.use_aishell2: - test_cuts.update( - { - "aishell-2_test": aishell2_test_cuts, - "aishell-2_dev": aishell2_dev_cuts, - } - ) - if self.use_librispeech: - test_cuts.update( - { - "librispeech_test_clean": test_clean_cuts, - "librispeech_test_other": test_other_cuts, - } - ) - return test_cuts - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" - ) diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/multi_zh_en/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py deleted file mode 120000 index 0573b88c5..000000000 --- a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- 
a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/multi_zh_en/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py deleted file mode 100755 index 2fcde550b..000000000 --- a/egs/multi_zh_en/ASR/zipformer/pretrained.py +++ /dev/null @@ -1,379 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 23 \ - --avg 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --epoch 23 \ - --avg 1 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bbpe_2000/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bbpe_2000/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bbpe_2000/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bbpe_2000/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bbpe_2000/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bbpe_2000/tokens.txt \ - --method fast_beam_search \ 
- /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp/epoch-xx.pt`. - -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall import smart_byte_decode - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to byte-level bpe model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> and <unk> are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - -
logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/multi_zh_en/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 7b9bd2d6c..000000000 --- a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,869 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import AsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. 
- Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
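# The get_init_states() docstring above describes the cached-state layout:
# 6 tensors per encoder layer, followed by cached_embed_left_pad and
# processed_lens. A minimal sketch of slicing such a list under that
# assumption; the helper name is illustrative, not part of the recipe:
from typing import List, Tuple

import torch


def split_states(
    states: List[torch.Tensor],
) -> Tuple[List[List[torch.Tensor]], torch.Tensor, torch.Tensor]:
    assert (len(states) - 2) % 6 == 0, len(states)
    num_layers = (len(states) - 2) // 6
    # per_layer[i] holds (cached_key, cached_nonlin_attn, cached_val1,
    # cached_val2, cached_conv1, cached_conv2) for layer i
    per_layer = [states[i * 6 : (i + 1) * 6] for i in range(num_layers)]
    return per_layer, states[-2], states[-1]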
- """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). 
- state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
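# The tail padding in decode_one_chunk() above reduces to simple frame
# arithmetic: a chunk of `chunk_size` encoder frames (50 Hz) consumes
# 2 * chunk_size feature frames (100 Hz), plus 7 frames eaten by
# encoder_embed and 2 * 3 frames of ConvNeXt right context. A minimal sketch;
# the helper name is illustrative:
def min_input_frames(chunk_size: int) -> int:
    subsample_context = 7  # encoder_embed maps T frames to (T - 7) // 2
    convnext_right_pad = 3  # (7 - 1) // 2 frames after subsampling
    return chunk_size * 2 + subsample_context + 2 * convnext_right_pad


assert min_input_frames(32) == 77  # e.g. --chunk-size 32 pads features to >= 77 frames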
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
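# Each entry in the `results` that save_results() writes out is a
# (cut_id, ref_words, hyp_words) tuple, as assembled in decode_dataset()
# above. A minimal sketch with made-up values:
example_result = (
    "aishell2-dev-0001",
    "THIS IS A TEST".split(),
    "THIS IS TEST".split(),
)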
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> and <unk> are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames =
find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - multi_dataset = MultiDataset(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_cuts = [test_sets_cuts[k] for k in test_sets] - for test_set, test_cut in zip(test_sets, test_cuts): - logging.info(f"Decoding {test_set}") - test_cut = test_cut.filter(remove_short_utt) - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/multi_zh_en/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py deleted file mode 100755 index 04bb41214..000000000 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ /dev/null @@ -1,1416 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from multi_dataset import MultiDataset -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import byte_encode, diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, - tokenize_by_CJK_char, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
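# Worked example of the formula returned below, using the values from the
# usage block above (--world-size 4, --max-duration 1000) and the default
# --ref-duration 600: each real batch counts as 1000 * 4 / 600 ≈ 6.67
# reference batches for set_batch_count().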
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_2000/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--use-tal-csasr", - type=str2bool, - default=False, - help="Whether to use TAL-CSASR training data.", - ) - - parser.add_argument( - "--use-librispeech", - type=str2bool, - default=False, - help="Whether to use LibriSpeech training data.", - ) - - parser.add_argument( - "--use-aishell2", - type=str2bool, - default=False, - help="Whether to use Aishell-2 training data.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - 
encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args) - - train_cuts = multi_dataset.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - def tokenize_and_encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = byte_encode(tokenize_by_CJK_char(text)) - c.supervisions[0].text = text - return c - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_cuts = train_cuts.map(tokenize_and_encode_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = data_module.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = multi_dataset.dev_cuts() - valid_dl = data_module.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/multi_zh_en/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/must_c/ST/local/compute_fbank_musan.py b/egs/must_c/ST/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/must_c/ST/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/must_c/ST/local/compute_fbank_must_c.py b/egs/must_c/ST/local/compute_fbank_must_c.py deleted file mode 100755 index 84de099d1..000000000 --- a/egs/must_c/ST/local/compute_fbank_must_c.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) - -""" -This file computes fbank features of the MuST-C dataset. -It looks for manifests in the directory "in_dir" and write -generated features to "out_dir". -""" -import argparse -import logging -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - FeatureSet, - LilcomChunkyWriter, - load_manifest, -) - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--in-dir", - type=Path, - required=True, - help="Input manifest directory", - ) - - parser.add_argument( - "--out-dir", - type=Path, - required=True, - help="Output directory where generated fbank features are saved.", - ) - - parser.add_argument( - "--tgt-lang", - type=str, - required=True, - help="Target language, e.g., zh, de, fr.", - ) - - parser.add_argument( - "--num-jobs", - type=int, - default=1, - help="Number of jobs for computing features", - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="""True to enable speed perturb with factors 0.9 and 1.1 on - the train subset. False (by default) to disable speed perturb. - """, - ) - - return parser.parse_args() - - -def compute_fbank_must_c( - in_dir: Path, - out_dir: Path, - tgt_lang: str, - num_jobs: int, - perturb_speed: bool, -): - out_dir.mkdir(parents=True, exist_ok=True) - - extractor = Fbank(FbankConfig(num_mel_bins=80)) - - parts = ["dev", "tst-COMMON", "tst-HE", "train"] - - prefix = "must_c" - suffix = "jsonl.gz" - for p in parts: - logging.info(f"Processing {p}") - - cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}" - if perturb_speed and p == "train": - cuts_path += "_sp" - - cuts_path += ".jsonl.gz" - - if Path(cuts_path).is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz" - supervisions_filename = ( - in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz" - ) - assert recordings_filename.is_file(), recordings_filename - assert supervisions_filename.is_file(), supervisions_filename - cut_set = CutSet.from_manifests( - recordings=load_manifest(recordings_filename), - supervisions=load_manifest(supervisions_filename), - ) - if perturb_speed and p == "train": - logging.info("Speed perturbing for the train dataset") - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp" - else: - storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}" - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=storage_path, - num_jobs=num_jobs, - storage_type=LilcomChunkyWriter, - ) - - logging.info("About to split cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - logging.info(f"Saved to {cuts_path}") - - -def main(): - args = get_args() - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - logging.info(vars(args)) - assert args.in_dir.is_dir(), args.in_dir - - compute_fbank_must_c( - in_dir=args.in_dir, - out_dir=args.out_dir, - tgt_lang=args.tgt_lang, - num_jobs=args.num_jobs, - perturb_speed=args.perturb_speed, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/must_c/ST/local/get_text.py b/egs/must_c/ST/local/get_text.py deleted file mode 100755 index f7b5816a8..000000000 --- a/egs/must_c/ST/local/get_text.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) -""" -This file prints the text field of supervisions from cutset to the console -""" - -import argparse -from pathlib import Path - -from lhotse import load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "manifest", - type=Path, - help="Input manifest", - ) - return parser.parse_args() - - -def main(): - args = get_args() - assert args.manifest.is_file(), args.manifest - - cutset = load_manifest_lazy(args.manifest) - for c in cutset: - for sup in c.supervisions: - print(sup.text) - - -if __name__ == "__main__": - main() diff --git a/egs/must_c/ST/local/get_words.py b/egs/must_c/ST/local/get_words.py deleted file mode 100755 index b32925099..000000000 --- a/egs/must_c/ST/local/get_words.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -""" -This file generates words.txt from the given transcript file. -""" - -import argparse -from pathlib import Path - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "transcript", - type=Path, - help="Input transcript file", - ) - return parser.parse_args() - - -def main(): - args = get_args() - assert args.transcript.is_file(), args.transcript - - word_set = set() - with open(args.transcript) as f: - for line in f: - words = line.strip().split() - for w in words: - word_set.add(w) - - # Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py - reserved1 = ["", "!SIL", "", ""] - reserved2 = ["#0", "", ""] - - for w in reserved1 + reserved2: - assert w not in word_set, w - - words = sorted(list(word_set)) - words = reserved1 + words + reserved2 - - for i, w in enumerate(words): - print(w, i) - - -if __name__ == "__main__": - main() diff --git a/egs/must_c/ST/local/normalize_punctuation.py b/egs/must_c/ST/local/normalize_punctuation.py deleted file mode 100644 index efd47e091..000000000 --- a/egs/must_c/ST/local/normalize_punctuation.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -import re - - -def normalize_punctuation(s: str, lang: str) -> str: - """ - This function implements - https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl - - Args: - s: - A string to be normalized. - lang: - The language to which `s` belongs - Returns: - Return a normalized string. 
- """ - # s/\r//g; - s = re.sub("\r", "", s) - - # remove extra spaces - # s/\(/ \(/g; - s = re.sub("\(", " (", s) # add a space before ( - - # s/\)/\) /g; s/ +/ /g; - s = re.sub("\)", ") ", s) # add a space after ) - s = re.sub(" +", " ", s) # convert multiple spaces to one - - # s/\) ([\.\!\:\?\;\,])/\)$1/g; - s = re.sub("\) ([\.\!\:\?\;\,])", r")\1", s) - - # s/\( /\(/g; - s = re.sub("\( ", "(", s) # remove space after ( - - # s/ \)/\)/g; - s = re.sub(" \)", ")", s) # remove space before ) - - # s/(\d) \%/$1\%/g; - s = re.sub("(\d) \%", r"\1%", s) # remove space between a digit and % - - # s/ :/:/g; - s = re.sub(" :", ":", s) # remove space before : - - # s/ ;/;/g; - s = re.sub(" ;", ";", s) # remove space before ; - - # normalize unicode punctuation - # s/\`/\'/g; - s = re.sub("`", "'", s) # replace ` with ' - - # s/\'\'/ \" /g; - s = re.sub("''", '"', s) # replace '' with " - - # s/„/\"/g; - s = re.sub("„", '"', s) # replace „ with " - - # s/“/\"/g; - s = re.sub("“", '"', s) # replace “ with " - - # s/”/\"/g; - s = re.sub("”", '"', s) # replace ” with " - - # s/–/-/g; - s = re.sub("–", "-", s) # replace – with - - - # s/—/ - /g; s/ +/ /g; - s = re.sub("—", " - ", s) - s = re.sub(" +", " ", s) # convert multiple spaces to one - - # s/´/\'/g; - s = re.sub("´", "'", s) - - # s/([a-z])‘([a-z])/$1\'$2/gi; - s = re.sub("([a-z])‘([a-z])", r"\1'\2", s, flags=re.IGNORECASE) - - # s/([a-z])’([a-z])/$1\'$2/gi; - s = re.sub("([a-z])’([a-z])", r"\1'\2", s, flags=re.IGNORECASE) - - # s/‘/\'/g; - s = re.sub("‘", "'", s) - - # s/‚/\'/g; - s = re.sub("‚", "'", s) - - # s/’/\"/g; - s = re.sub("’", '"', s) - - # s/''/\"/g; - s = re.sub("''", '"', s) - - # s/´´/\"/g; - s = re.sub("´´", '"', s) - - # s/…/.../g; - s = re.sub("…", "...", s) - - # French quotes - - # s/ « / \"/g; - s = re.sub(" « ", ' "', s) - - # s/« /\"/g; - s = re.sub("« ", '"', s) - - # s/«/\"/g; - s = re.sub("«", '"', s) - - # s/ » /\" /g; - s = re.sub(" » ", '" ', s) - - # s/ »/\"/g; - s = re.sub(" »", '"', s) - - # s/»/\"/g; - s = re.sub("»", '"', s) - - # handle pseudo-spaces - - # s/ \%/\%/g; - s = re.sub(" %", r"%", s) - - # s/nº /nº /g; - s = re.sub("nº ", "nº ", s) - - # s/ :/:/g; - s = re.sub(" :", ":", s) - - # s/ ºC/ ºC/g; - s = re.sub(" ºC", " ºC", s) - - # s/ cm/ cm/g; - s = re.sub(" cm", " cm", s) - - # s/ \?/\?/g; - s = re.sub(" \?", "\?", s) - - # s/ \!/\!/g; - s = re.sub(" \!", "\!", s) - - # s/ ;/;/g; - s = re.sub(" ;", ";", s) - - # s/, /, /g; s/ +/ /g; - s = re.sub(", ", ", ", s) - s = re.sub(" +", " ", s) - - if lang == "en": - # English "quotation," followed by comma, style - # s/\"([,\.]+)/$1\"/g; - s = re.sub('"([,\.]+)', r'\1"', s) - elif lang in ("cs", "cz"): - # Czech is confused - pass - else: - # German/Spanish/French "quotation", followed by comma, style - # s/,\"/\",/g; - s = re.sub(',"', '",', s) - - # s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence - s = re.sub('(\.+)"(\s*[^<])', r'"\1\2', s) - - if lang in ("de", "es", "cz", "cs", "fr"): - # s/(\d) (\d)/$1,$2/g; - s = re.sub("(\d) (\d)", r"\1,\2", s) - else: - # s/(\d) (\d)/$1.$2/g; - s = re.sub("(\d) (\d)", r"\1.\2", s) - - return s diff --git a/egs/must_c/ST/local/prepare_lang.py b/egs/must_c/ST/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/must_c/ST/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/must_c/ST/local/prepare_lang_bpe.py b/egs/must_c/ST/local/prepare_lang_bpe.py deleted file mode 120000 
index 36b40e7fc..000000000 --- a/egs/must_c/ST/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/must_c/ST/local/preprocess_must_c.py b/egs/must_c/ST/local/preprocess_must_c.py deleted file mode 100755 index 1ba282bf4..000000000 --- a/egs/must_c/ST/local/preprocess_must_c.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -""" -This script normalizes transcripts from supervisions. - -Usage: - ./local/preprocess_must_c.py \ - --manifest-dir ./data/manifests/v1.0/ \ - --tgt-lang de -""" - -import argparse -import logging -import re -from functools import partial -from pathlib import Path - -from lhotse.recipes.utils import read_manifests_if_cached -from normalize_punctuation import normalize_punctuation -from remove_non_native_characters import remove_non_native_characters -from remove_punctuation import remove_punctuation - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--manifest-dir", - type=Path, - required=True, - help="Manifest directory", - ) - parser.add_argument( - "--tgt-lang", - type=str, - required=True, - help="Target language, e.g., zh, de, fr.", - ) - return parser.parse_args() - - -def preprocess_must_c(manifest_dir: Path, tgt_lang: str): - normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang) - remove_non_native_characters_lang = partial( - remove_non_native_characters, lang=tgt_lang - ) - - prefix = "must_c" - suffix = "jsonl.gz" - parts = ["dev", "tst-COMMON", "tst-HE", "train"] - for p in parts: - logging.info(f"Processing {p}") - name = f"en-{tgt_lang}_{p}" - - # norm: normalization - # rm: remove punctuation - dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz" - if dst_name.is_file(): - logging.info(f"{dst_name} exists - skipping") - continue - - manifests = read_manifests_if_cached( - dataset_parts=name, - output_dir=manifest_dir, - prefix=prefix, - suffix=suffix, - types=("supervisions",), - ) - if name not in manifests: - raise RuntimeError(f"Processing {p} failed.") - - supervisions = manifests[name]["supervisions"] - supervisions = supervisions.transform_text(normalize_punctuation_lang) - supervisions = supervisions.transform_text(remove_punctuation) - supervisions = supervisions.transform_text(lambda x: x.lower()) - supervisions = supervisions.transform_text(remove_non_native_characters_lang) - supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x)) - - supervisions.to_file(dst_name) - - -def main(): - args = get_args() - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - logging.info(vars(args)) - assert args.manifest_dir.is_dir(), args.manifest_dir - - preprocess_must_c( - manifest_dir=args.manifest_dir, - tgt_lang=args.tgt_lang, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/must_c/ST/local/remove_non_native_characters.py b/egs/must_c/ST/local/remove_non_native_characters.py deleted file mode 100755 index f61fbd16b..000000000 --- a/egs/must_c/ST/local/remove_non_native_characters.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) - -import re - - -def remove_non_native_characters(s: str, lang: str): - if lang == "de": - # ä -> ae - # ö -> oe - # ü -> ue - # ß -> ss - - s = re.sub("ä", "ae", s) - s = re.sub("ö", "oe", s) - s = re.sub("ü", "ue", s) - s = re.sub("ß", "ss", s) - # keep only a-z and spaces - # note: ' is removed - s = re.sub(r"[^a-z\s]", "", s) - - return s diff --git a/egs/must_c/ST/local/remove_punctuation.py b/egs/must_c/ST/local/remove_punctuation.py deleted file mode 100644 index 723946ec3..000000000 --- a/egs/must_c/ST/local/remove_punctuation.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -import re -import string - - -def remove_punctuation(s: str) -> str: - """ - It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl - """ - - # Remove punctuation except apostrophe - # s//spacemark/g; # for scoring - s = re.sub("", "spacemark", s) - - # s/'/apostrophe/g; - s = re.sub("'", "apostrophe", s) - - # s/[[:punct:]]//g; - s = s.translate(str.maketrans("", "", string.punctuation)) - # string punctuation returns the following string - # !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ - # See - # https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string - - # s/apostrophe/'/g; - s = re.sub("apostrophe", "'", s) - - # s/spacemark//g; # for scoring - s = re.sub("spacemark", "", s) - - # remove whitespace - # s/\s+/ /g; - s = re.sub("\s+", " ", s) - - # s/^\s+//; - s = re.sub("^\s+", "", s) - - # s/\s+$//; - s = re.sub("\s+$", "", s) - - return s diff --git a/egs/must_c/ST/local/test_normalize_punctuation.py b/egs/must_c/ST/local/test_normalize_punctuation.py deleted file mode 100755 index 9079858c8..000000000 --- a/egs/must_c/ST/local/test_normalize_punctuation.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) - -from normalize_punctuation import normalize_punctuation - - -def test_normalize_punctuation(): - # s/\r//g; - s = "a\r\nb\r\n" - n = normalize_punctuation(s, lang="en") - assert "\r" not in n - assert len(s) - 2 == len(n), (len(s), len(n)) - - # s/\(/ \(/g; - s = "(ab (c" - n = normalize_punctuation(s, lang="en") - assert n == " (ab (c", n - - # s/\)/\) /g; - s = "a)b c)" - n = normalize_punctuation(s, lang="en") - assert n == "a) b c) " - - # s/ +/ /g; - s = " a b c d " - n = normalize_punctuation(s, lang="en") - assert n == " a b c d " - - # s/\) ([\.\!\:\?\;\,])/\)$1/g; - for i in ".!:?;,": - s = f"a) {i}" - n = normalize_punctuation(s, lang="en") - assert n == f"a){i}" - - # s/\( /\(/g; - s = "a( b" - n = normalize_punctuation(s, lang="en") - assert n == "a (b", n - - # s/ \)/\)/g; - s = "ab ) a" - n = normalize_punctuation(s, lang="en") - assert n == "ab) a", n - - # s/(\d) \%/$1\%/g; - s = "1 %a" - n = normalize_punctuation(s, lang="en") - assert n == "1%a", n - - # s/ :/:/g; - s = "a :" - n = normalize_punctuation(s, lang="en") - assert n == "a:", n - - # s/ ;/;/g; - s = "a ;" - n = normalize_punctuation(s, lang="en") - assert n == "a;", n - - # s/\`/\'/g; - s = "`a`" - n = normalize_punctuation(s, lang="en") - assert n == "'a'", n - - # s/\'\'/ \" /g; - s = "''a''" - n = normalize_punctuation(s, lang="en") - assert n == '"a"', n - - # s/„/\"/g; - s = '„a"' - n = normalize_punctuation(s, lang="en") - assert n == '"a"', n - - # s/“/\"/g; - s = "“a„" - n = normalize_punctuation(s, lang="en") - assert n == '"a"', n - - # s/”/\"/g; - s = "“a”" - n = normalize_punctuation(s, lang="en") - assert n == '"a"', n - - # s/–/-/g; - s = "a–b" - n = normalize_punctuation(s, lang="en") - assert n == "a-b", n - - # s/—/ - /g; s/ +/ /g; - s = "a—b" - n = normalize_punctuation(s, lang="en") - assert n == "a - b", n - - # s/´/\'/g; - s = "a´b" - n = normalize_punctuation(s, lang="en") - assert n == "a'b", n - - # s/([a-z])‘([a-z])/$1\'$2/gi; - for i in "‘’": - s = f"a{i}B" - n = normalize_punctuation(s, lang="en") - assert n == "a'B", n - - s = f"A{i}B" - n = normalize_punctuation(s, lang="en") - assert n == "A'B", n - - s = f"A{i}b" - n = normalize_punctuation(s, lang="en") - assert n == "A'b", n - - # s/‘/\'/g; - # s/‚/\'/g; - for i in "‘‚": - s = f"a{i}b" - n = normalize_punctuation(s, lang="en") - assert n == "a'b", n - - # s/’/\"/g; - s = "’" - n = normalize_punctuation(s, lang="en") - assert n == '"', n - - # s/''/\"/g; - s = "''" - n = normalize_punctuation(s, lang="en") - assert n == '"', n - - # s/´´/\"/g; - s = "´´" - n = normalize_punctuation(s, lang="en") - assert n == '"', n - - # s/…/.../g; - s = "…" - n = normalize_punctuation(s, lang="en") - assert n == "...", n - - # s/ « / \"/g; - s = "a « b" - n = normalize_punctuation(s, lang="en") - assert n == 'a "b', n - - # s/« /\"/g; - s = "a « b" - n = normalize_punctuation(s, lang="en") - assert n == 'a "b', n - - # s/«/\"/g; - s = "a«b" - n = normalize_punctuation(s, lang="en") - assert n == 'a"b', n - - # s/ » /\" /g; - s = " » " - n = normalize_punctuation(s, lang="en") - assert n == '" ', n - - # s/ »/\"/g; - s = " »" - n = normalize_punctuation(s, lang="en") - assert n == '"', n - - # s/»/\"/g; - s = "»" - n = normalize_punctuation(s, lang="en") - assert n == '"', n - - # s/ \%/\%/g; - s = " %" - n = normalize_punctuation(s, lang="en") - assert n == "%", n - - # s/ :/:/g; - s = " :" - n = normalize_punctuation(s, lang="en") - assert n == ":", n - - # s/(\d) (\d)/$1.$2/g; - s = "2 3" - n = 
normalize_punctuation(s, lang="en") - assert n == "2.3", n - - # s/(\d) (\d)/$1,$2/g; - s = "2 3" - n = normalize_punctuation(s, lang="de") - assert n == "2,3", n - - -def main(): - test_normalize_punctuation() - - -if __name__ == "__main__": - main() diff --git a/egs/must_c/ST/local/test_remove_non_native_characters.py b/egs/must_c/ST/local/test_remove_non_native_characters.py deleted file mode 100755 index ecf8569cf..000000000 --- a/egs/must_c/ST/local/test_remove_non_native_characters.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) - -from remove_non_native_characters import remove_non_native_characters - - -def test_remove_non_native_characters(): - s = "Ich heiße xxx好的01 fangjun".lower() - n = remove_non_native_characters(s, lang="de") - assert n == "ich heisse xxx fangjun", n - - s = "äÄ".lower() - n = remove_non_native_characters(s, lang="de") - assert n == "aeae", n - - s = "öÖ".lower() - n = remove_non_native_characters(s, lang="de") - assert n == "oeoe", n - - s = "üÜ".lower() - n = remove_non_native_characters(s, lang="de") - assert n == "ueue", n - - -if __name__ == "__main__": - test_remove_non_native_characters() diff --git a/egs/must_c/ST/local/test_remove_punctuation.py b/egs/must_c/ST/local/test_remove_punctuation.py deleted file mode 100755 index a4f318550..000000000 --- a/egs/must_c/ST/local/test_remove_punctuation.py +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env python3 - -from remove_punctuation import remove_punctuation - - -def test_remove_punctuation(): - s = "a,b'c!#" - n = remove_punctuation(s) - assert n == "ab'c", n - - s = " ab " # remove leading and trailing spaces - n = remove_punctuation(s) - assert n == "ab", n - - -if __name__ == "__main__": - test_remove_punctuation() diff --git a/egs/must_c/ST/local/train_bpe_model.py b/egs/must_c/ST/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/must_c/ST/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/must_c/ST/local/validate_bpe_lexicon.py b/egs/must_c/ST/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/must_c/ST/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/must_c/ST/prepare.sh b/egs/must_c/ST/prepare.sh deleted file mode 100755 index d16bb3d0b..000000000 --- a/egs/must_c/ST/prepare.sh +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=10 -stage=0 -stop_stage=100 - -version=v1.0 -tgt_lang=de -dl_dir=$PWD/download - -must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data - -# We assume dl_dir (download dir) contains the following -# directories and files. -# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE} -# -# Please go to https://ict.fbk.eu/must-c-releases/ -# to download and untar the dataset if you have not already done this. - -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. 
-# It will generate -# data/lang_bpe_${tgt_lang}_xxx -# data/lang_bpe_${tgt_lang}_yyy -# if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ ! -d $must_c_dir ]; then - log "$must_c_dir does not exist" - exit 1 -fi - -for d in dev train tst-COMMON tst-HE; do - if [ ! -d $must_c_dir/$d ]; then - log "$must_c_dir/$d does not exist!" - exit 1 - fi -done - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download musan" - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang" - mkdir -p data/manifests/$version - if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then - lhotse prepare must-c \ - -j $nj \ - --tgt-lang $tgt_lang \ - $dl_dir/must-c/$version/ \ - data/manifests/$version/ - - touch data/manifests/$version/.${tgt_lang}.manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Text normalization for $version with target language $tgt_lang" - if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then - ./local/preprocess_must_c.py \ - --manifest-dir ./data/manifests/$version/ \ - --tgt-lang $tgt_lang - touch ./data/manifests/$version/.$tgt_lang.norm.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for $version with target language $tgt_lang" - mkdir -p data/fbank/$version/ - if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then - ./local/compute_fbank_must_c.py \ - --in-dir ./data/manifests/$version/ \ - --out-dir ./data/fbank/$version/ \ - --tgt-lang $tgt_lang \ - --num-jobs $nj - - ./local/compute_fbank_must_c.py \ - --in-dir ./data/manifests/$version/ \ - --out-dir ./data/fbank/$version/ \ - --tgt-lang $tgt_lang \ - --num-jobs $nj \ - --perturb-speed 1 - - touch data/fbank/$version/.$tgt_lang.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/ - mkdir -p $lang_dir - if [ ! -f $lang_dir/transcript_words.txt ]; then - ./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/words.txt ]; then - ./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt - fi - - if [ ! 
-f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - done -fi diff --git a/egs/must_c/ST/shared b/egs/must_c/ST/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/must_c/ST/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/ptb/LM/README.md b/egs/ptb/LM/README.md deleted file mode 100644 index 7629a950d..000000000 --- a/egs/ptb/LM/README.md +++ /dev/null @@ -1,18 +0,0 @@ -## Description - -(Note: the experiments here are only about language modeling) - -ptb is short for Penn Treebank. - - -About the Penn Treebank corpus: - - This corpus is free for research purposes - - ptb.train.txt: train set - - ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training) - - ptb.test.txt: test set for reporting perplexity - -You can download the dataset from one of the following URLs: - -- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage -- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz -- https://deepai.org/dataset/penn-treebank diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py deleted file mode 120000 index abc00d421..000000000 --- a/egs/ptb/LM/local/prepare_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py deleted file mode 100755 index bed3856e4..000000000 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input the filename of LM training data -generated by ./local/prepare_lm_training_data.py and sorts -it by sentence length. - -Sentence length equals to the number of BPE tokens in a sentence. 
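Stage 6 above trains the BPE model via `./local/train_bpe_model.py`, which this diff removes only as a symlink into the librispeech recipe, and the LM sorting script introduced next defines "sentence length" as the number of BPE tokens. A hedged sketch tying the two together — paths are illustrative, and the real training script sets more options (e.g. special symbols, character coverage for the target language):

```python
import sentencepiece as spm

# Bare-bones stand-in for ./local/train_bpe_model.py; the output layout
# loosely follows the lang_dir convention used in the stages above.
spm.SentencePieceTrainer.train(
    input="data/lang_bpe_500/transcript_words.txt",
    model_prefix="data/lang_bpe_500/bpe",
    model_type="bpe",
    vocab_size=500,
    character_coverage=1.0,
)

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
pieces = sp.encode("the quick brown fox", out_type=str)  # toy sentence
print(len(pieces), pieces)  # len(pieces) is the "sentence length" used for sorting
```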
-""" - -import argparse -import logging -from pathlib import Path - -import k2 -import numpy as np -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--in-lm-data", - type=str, - help="Input LM training data, e.g., data/bpe_500/lm_data.pt", - ) - - parser.add_argument( - "--out-lm-data", - type=str, - help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", - ) - - parser.add_argument( - "--out-statistics", - type=str, - help="Statistics about LM training data., data/bpe_500/statistics.txt", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - in_lm_data = Path(args.in_lm_data) - out_lm_data = Path(args.out_lm_data) - assert in_lm_data.is_file(), f"{in_lm_data}" - if out_lm_data.is_file(): - logging.warning(f"{out_lm_data} exists - skipping") - return - data = torch.load(in_lm_data) - words2bpe = data["words"] - sentences = data["sentences"] - sentence_lengths = data["sentence_lengths"] - - num_sentences = sentences.dim0 - assert num_sentences == sentence_lengths.numel(), ( - num_sentences, - sentence_lengths.numel(), - ) - - indices = torch.argsort(sentence_lengths, descending=True) - - sorted_sentences = sentences[indices.to(torch.int32)] - sorted_sentence_lengths = sentence_lengths[indices] - - # Check that sentences are ordered by length - assert num_sentences == sorted_sentences.dim0, ( - num_sentences, - sorted_sentences.dim0, - ) - - cur = None - for i in range(num_sentences): - word_ids = sorted_sentences[i] - token_ids = words2bpe[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - if cur is not None: - assert cur >= token_ids.numel(), (cur, token_ids.numel()) - - cur = token_ids.numel() - assert cur == sorted_sentence_lengths[i] - - data["sentences"] = sorted_sentences - data["sentence_lengths"] = sorted_sentence_lengths - torch.save(data, args.out_lm_data) - logging.info(f"Saved to {args.out_lm_data}") - - statistics = Path(args.out_statistics) - - # Write statistics - num_words = sorted_sentences.numel() - num_tokens = sentence_lengths.sum().item() - max_sentence_length = sentence_lengths[indices[0]] - min_sentence_length = sentence_lengths[indices[-1]] - - step = 10 - hist, bins = np.histogram( - sentence_lengths.numpy(), - bins=np.arange(1, max_sentence_length + step, step), - ) - - histogram = np.stack((bins[:-1], hist)).transpose() - - with open(statistics, "w") as f: - f.write(f"num_sentences: {num_sentences}\n") - f.write(f"num_words: {num_words}\n") - f.write(f"num_tokens: {num_tokens}\n") - f.write(f"max_sentence_length: {max_sentence_length}\n") - f.write(f"min_sentence_length: {min_sentence_length}\n") - f.write("histogram:\n") - f.write(" bin count percent\n") - for row in histogram: - f.write( - f"{int(row[0]):>5} {int(row[1]):>5} " - f"{100.*row[1]/num_sentences:.3f}%\n" - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py deleted file mode 100755 index 3790045fa..000000000 --- a/egs/ptb/LM/local/test_prepare_lm_training_data.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not 
use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from pathlib import Path - -import sentencepiece as spm -import torch - - -def main(): - lm_training_data = Path("./data/bpe_500/lm_data.pt") - bpe_model = Path("./data/bpe_500/bpe.model") - if not lm_training_data.exists(): - logging.warning(f"{lm_training_data} does not exist - skipping") - return - - if not bpe_model.exists(): - logging.warning(f"{bpe_model} does not exist - skipping") - return - - sp = spm.SentencePieceProcessor() - sp.load(str(bpe_model)) - - data = torch.load(lm_training_data) - words2bpe = data["words"] - sentences = data["sentences"] - - ss = [] - unk = sp.decode(sp.unk_id()).strip() - for i in range(10): - s = sp.decode(words2bpe[sentences[i]].values.tolist()) - s = s.replace(unk, "") - ss.append(s) - - for s in ss: - print(s) - # You can compare the output with the first 10 lines of ptb.train.txt - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/ptb/LM/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh deleted file mode 100755 index 69fab999a..000000000 --- a/egs/ptb/LM/prepare.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download -# The following files will be downloaded to $dl_dir -# - ptb.train.txt -# - ptb.valid.txt -# - ptb.test.txt - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/bpe_xxx, data/bpe_yyy -# if the array contains xxx, yyy -vocab_sizes=( - 500 - # 1000 - # 2000 - # 5000 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data -mkdir -p $dl_dir - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download data" - - # Caution: The downloaded data has already been normalized for LM training. - - if [ ! 
-f $dl_dir/.complete ]; then - url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data - wget --directory-prefix $dl_dir $url/ptb.train.txt - wget --directory-prefix $dl_dir $url/ptb.valid.txt - wget --directory-prefix $dl_dir $url/ptb.test.txt - touch $dl_dir/.complete - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Train BPE model" - - # Caution: You have to use the same bpe model for training your acoustic model - # Caution: You have to use the same bpe model for training your acoustic model - # Caution: You have to use the same bpe model for training your acoustic model - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $dl_dir/ptb.train.txt - done -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Generate LM training data" - # Note: ptb.train.txt has already been normalized - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/ptb.train.txt \ - --lm-archive $out_dir/lm_data.pt - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/ptb.valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/ptb.test.txt \ - --lm-archive $out_dir/lm_data-test.pt - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Sort LM training data" - # Sort LM training data generated in stage 1 - # by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. - - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt - done -fi diff --git a/egs/ptb/LM/rnn_lm b/egs/ptb/LM/rnn_lm deleted file mode 120000 index 87f29771e..000000000 --- a/egs/ptb/LM/rnn_lm +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/rnn_lm \ No newline at end of file diff --git a/egs/ptb/LM/shared b/egs/ptb/LM/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/ptb/LM/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh deleted file mode 100755 index cb70b7856..000000000 --- a/egs/ptb/LM/train-rnn-lm.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env bash - -# Please run ./prepare.sh first - -stage=-1 -stop_stage=100 - -# Number of GPUs to use for training -world_size=1 - -# Number of epochs to train -num_epochs=20 - -# Use this epoch for computing ppl -use_epoch=19 - -# number of models to average for computing ppl -use_avg=2 - -exp_dir=./my-rnnlm-exp - -. 
shared/parse_options.sh || exit 1 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Training RNN LM" - - ./rnn_lm/train.py \ - --exp-dir $exp_dir \ - --start-epoch 0 \ - --num-epochs $num_epochs \ - --world-size $world_size \ - --use-fp16 0 \ - --vocab-size 500 \ - --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \ - --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \ - --embedding-dim 800 \ - --hidden-dim 200 \ - --num-layers 2 \ - --tie-weights false \ - --batch-size 50 -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Computing perplexity" - - ./rnn_lm/compute_perplexity.py \ - --exp-dir $exp_dir \ - --epoch $use_epoch \ - --avg $use_avg \ - --vocab-size 500 \ - --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \ - --embedding-dim 800 \ - --hidden-dim 200 \ - --num-layers 2 \ - --tie-weights false \ - --batch-size 50 -fi diff --git a/egs/reazonspeech/ASR/README.md b/egs/reazonspeech/ASR/README.md deleted file mode 100644 index ad5c15de3..000000000 --- a/egs/reazonspeech/ASR/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Introduction - - - -**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio. - - - -The dataset is available on Hugging Face. For more details, please visit: - -- Dataset: https://huggingface.co/datasets/reazon-research/reazonspeech -- Paper: https://research.reazon.jp/_static/reazonspeech_nlp2023.pdf - - - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - - - -There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -| ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- | -| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | - -The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. 
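The README paragraph above describes the stateless prediction network as "Embedding + Conv1d", with the Conv1d placed right after the input embedding layer. A compact PyTorch sketch of that idea — dimensions and context size are illustrative, not the recipe's exact configuration:

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding over previous labels followed by a causal Conv1d that only
    sees the last `context_size` labels, as described above."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) previous label ids -> (N, U, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)                  # (N, C, U)
        emb = nn.functional.pad(emb, (self.context_size - 1, 0))  # left pad = causal
        return self.conv(emb).permute(0, 2, 1)


out = StatelessDecoder(vocab_size=500)(torch.randint(0, 500, (2, 7)))
print(out.shape)  # torch.Size([2, 7, 512])
```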
- diff --git a/egs/reazonspeech/ASR/RESULTS.md b/egs/reazonspeech/ASR/RESULTS.md deleted file mode 100644 index 92610d75b..000000000 --- a/egs/reazonspeech/ASR/RESULTS.md +++ /dev/null @@ -1,87 +0,0 @@ -## Results - -### Zipformer - -#### Non-streaming - -##### large-scaled model, number of model parameters: 159337842, i.e., 159.34 M - -| decoding method | In-Distribution CER | JSUT | CommonVoice | TEDx | comment | -| :------------------: | :-----------------: | :--: | :---------: | :---: | :----------------: | -| greedy search | 4.2 | 6.7 | 7.84 | 17.9 | --epoch 39 --avg 7 | -| modified beam search | 4.13 | 6.77 | 7.69 | 17.82 | --epoch 39 --avg 7 | - -The training command is: - -```shell -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 40 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-large \ - --causal 0 \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --lang data/lang_char \ - --max-duration 1600 -``` - -The decoding command is: - -```shell -./zipformer/decode.py \ - --epoch 40 \ - --avg 16 \ - --exp-dir zipformer/exp-large \ - --max-duration 600 \ - --causal 0 \ - --decoding-method greedy_search \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --lang data/lang_char \ - --blank-penalty 0 -``` - -#### Streaming - -We have not completed evaluation of our models yet and will add evaluation results here once it's completed. - -The training command is: -```shell -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 40 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-large \ - --causal 1 \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --lang data/lang_char \ - --max-duration 1600 -``` - -The decoding command is: - -```shell -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp-large \ - --lang data/lang_char \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 -``` - diff --git a/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py b/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py deleted file mode 100644 index af7841406..000000000 --- a/egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
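RESULTS.md above quotes the parameter count of the large model (159337842, i.e. 159.34 M). For reference, that figure is just the sum of element counts over the model's parameters; with any instantiated PyTorch module (a small stand-in below, not the recipe's Zipformer):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(80, 512), nn.ReLU(), nn.Linear(512, 500))  # stand-in model
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params} parameters ({num_params / 1e6:.2f} M)")
```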
- - -import argparse -import logging -import os -from pathlib import Path -from typing import List, Tuple - -import torch - -# fmt: off -from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - RecordingSet, - SupervisionSet, -) - -# fmt: on - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -RNG_SEED = 42 -concat_params = {"gap": 1.0, "maxlen": 10.0} - - -def make_cutset_blueprints( - manifest_dir: Path, -) -> List[Tuple[str, CutSet]]: - cut_sets = [] - - # Create test dataset - logging.info("Creating test cuts.") - cut_sets.append( - ( - "test", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_test.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_test.jsonl.gz" - ), - ), - ) - ) - - # Create dev dataset - logging.info("Creating dev cuts.") - cut_sets.append( - ( - "dev", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_dev.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz" - ), - ), - ) - ) - - # Create train dataset - logging.info("Creating train cuts.") - cut_sets.append( - ( - "train", - CutSet.from_manifests( - recordings=RecordingSet.from_file( - manifest_dir / "reazonspeech_recordings_train.jsonl.gz" - ), - supervisions=SupervisionSet.from_file( - manifest_dir / "reazonspeech_supervisions_train.jsonl.gz" - ), - ), - ) - ) - return cut_sets - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("-m", "--manifest-dir", type=Path) - return parser.parse_args() - - -def main(): - args = get_args() - - extractor = Fbank(FbankConfig(num_mel_bins=80)) - num_jobs = min(16, os.cpu_count()) - - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - if (args.manifest_dir / ".reazonspeech-fbank.done").exists(): - logging.info( - "Previous fbank computed for ReazonSpeech found. " - f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank." - ) - return - else: - cut_sets = make_cutset_blueprints(args.manifest_dir) - for part, cut_set in cut_sets: - logging.info(f"Processing {part}") - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - num_jobs=num_jobs, - storage_path=(args.manifest_dir / f"feats_{part}").as_posix(), - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz") - - logging.info("All fbank computed for ReazonSpeech.") - (args.manifest_dir / ".reazonspeech-fbank.done").touch() - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/local/display_manifest_statistics.py b/egs/reazonspeech/ASR/local/display_manifest_statistics.py deleted file mode 100644 index ace1dd73f..000000000 --- a/egs/reazonspeech/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -from pathlib import Path - -from lhotse import CutSet, load_manifest - -ARGPARSE_DESCRIPTION = """ -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in -pruned_transducer_stateless5/train.py for usage. -""" - - -def get_parser(): - parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") - - return parser.parse_args() - - -def main(): - args = get_parser() - - for part in ["train", "dev"]: - path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz" - cuts: CutSet = load_manifest(path) - - print("\n---------------------------------\n") - print(path.name + ":") - cuts.describe() - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/local/prepare_lang_char.py b/egs/reazonspeech/ASR/local/prepare_lang_char.py deleted file mode 100644 index 19c5f4a31..000000000 --- a/egs/reazonspeech/ASR/local/prepare_lang_char.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help=( - "Name of lang dir. 
" - "If not set, this will default to lang_char_{trans-mode}" - ), - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.basicConfig( - format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), - level=logging.INFO, - ) - - sysdef_string = set(["", "", "", " "]) - - token_set = set() - logging.info(f"Creating vocabulary from {args.train_cut}.") - train_cut: CutSet = CutSet.from_file(args.train_cut) - for cut in train_cut: - for sup in cut.supervisions: - token_set.update(sup.text) - - token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] - args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "tokens.txt").write_text( - "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) - ) - - (args.lang_dir / "lang_type").write_text("char") - logging.info("Done.") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/local/utils/asr_datamodule.py b/egs/reazonspeech/ASR/local/utils/asr_datamodule.py deleted file mode 100644 index e70370760..000000000 --- a/egs/reazonspeech/ASR/local/utils/asr_datamodule.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class ReazonSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/dev/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=False, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - - transforms = [] - input_transforms = [] - - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - 
persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" - ) diff --git a/egs/reazonspeech/ASR/local/utils/tokenizer.py b/egs/reazonspeech/ASR/local/utils/tokenizer.py deleted file mode 100644 index ba71cff89..000000000 --- a/egs/reazonspeech/ASR/local/utils/tokenizer.py +++ /dev/null @@ -1,252 +0,0 @@ -import argparse -from pathlib import Path -from typing import Callable, List, Union - -import sentencepiece as spm -from k2 import SymbolTable - - -class Tokenizer: - text2word: Callable[[str], List[str]] - - @staticmethod - def add_arguments(parser: argparse.ArgumentParser): - group = parser.add_argument_group(title="Lang related options") - group.add_argument("--lang", type=Path, help="Path to lang directory.") - - group.add_argument( - "--lang-type", - type=str, - default=None, - help=( - "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " - "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" - ), - ) - - @staticmethod - def Load(lang_dir: Path, lang_type="", oov=""): - - if not lang_type: - assert (lang_dir / "lang_type").exists(), "lang_type not specified." - lang_type = (lang_dir / "lang_type").read_text().strip() - - tokenizer = None - - if lang_type == "bpe": - assert ( - lang_dir / "bpe.model" - ).exists(), f"No BPE .model could be found in {lang_dir}." - tokenizer = spm.SentencePieceProcessor() - tokenizer.Load(str(lang_dir / "bpe.model")) - elif lang_type == "char": - tokenizer = CharTokenizer(lang_dir, oov=oov) - else: - raise NotImplementedError(f"{lang_type} not supported at the moment.") - - return tokenizer - - load = Load - - def PieceToId(self, piece: str) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - piece_to_id = PieceToId - - def IdToPiece(self, id: int) -> str: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - id_to_piece = IdToPiece - - def GetPieceSize(self) -> int: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - get_piece_size = GetPieceSize - - def __len__(self) -> int: - return self.get_piece_size() - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - raise NotImplementedError( - "You need to implement this function in the child class." 
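Putting the `ReazonSpeechAsrDataModule` above to use follows the usual icefall pattern: register its options on an `argparse` parser, construct it, and ask it for dataloaders built from the cut manifests. A hedged usage sketch — it assumes you run it from the recipe directory with the manifests from `prepare.sh` in place, and `--max-duration 600` is just an example value:

```python
import argparse

from asr_datamodule import ReazonSpeechAsrDataModule

parser = argparse.ArgumentParser()
ReazonSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "600"])

datamodule = ReazonSpeechAsrDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
valid_dl = datamodule.valid_dataloaders(datamodule.valid_cuts())
test_dl = datamodule.test_dataloaders(datamodule.test_cuts())
```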
- ) - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def EncodeAsIds(self, input: str) -> List[int]: - return self.EncodeAsIdsBatch([input])[0] - - def EncodeAsPieces(self, input: str) -> List[str]: - return self.EncodeAsPiecesBatch([input])[0] - - def Encode( - self, input: Union[str, List[str]], out_type=int - ) -> Union[List, List[List]]: - if not input: - return [] - - if isinstance(input, list): - if out_type is int: - return self.EncodeAsIdsBatch(input) - if out_type is str: - return self.EncodeAsPiecesBatch(input) - - if out_type is int: - return self.EncodeAsIds(input) - if out_type is str: - return self.EncodeAsPieces(input) - - encode = Encode - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def DecodeIds(self, input: List[int]) -> str: - return self.DecodeIdsBatch([input])[0] - - def DecodePieces(self, input: List[str]) -> str: - return self.DecodePiecesBatch([input])[0] - - def Decode( - self, - input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], - ) -> Union[List[str], str]: - - if not input: - return "" - - if isinstance(input, int): - return self.id_to_piece(input) - elif isinstance(input, str): - raise TypeError( - "Unlike spm.SentencePieceProcessor, cannot decode from type str." - ) - - if isinstance(input[0], list): - if not input[0] or isinstance(input[0][0], int): - return self.DecodeIdsBatch(input) - - if isinstance(input[0][0], str): - return self.DecodePiecesBatch(input) - - if isinstance(input[0], int): - return self.DecodeIds(input) - if isinstance(input[0], str): - return self.DecodePieces(input) - - raise RuntimeError("Unknown input type") - - decode = Decode - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - raise NotImplementedError( - "You need to implement this function in the child class." - ) - - def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: - if isinstance(input, list): - return self.SplitBatch(input) - elif isinstance(input, str): - return self.SplitBatch([input])[0] - raise RuntimeError("Unknown input type") - - split = Split - - -class CharTokenizer(Tokenizer): - def __init__(self, lang_dir: Path, oov="", sep=""): - assert ( - lang_dir / "tokens.txt" - ).exists(), f"tokens.txt could not be found in {lang_dir}." - token_table = SymbolTable.from_file(lang_dir / "tokens.txt") - assert ( - "#0" not in token_table - ), "This tokenizer does not support disambig symbols." 
- self._id2sym = token_table._id2sym - self._sym2id = token_table._sym2id - self.oov = oov - self.oov_id = self._sym2id[oov] - self.sep = sep - if self.sep: - self.text2word = lambda x: x.split(self.sep) - else: - self.text2word = lambda x: list(x.replace(" ", "")) - - def piece_to_id(self, piece: str) -> int: - try: - return self._sym2id[piece] - except KeyError: - return self.oov_id - - def id_to_piece(self, id: int) -> str: - return self._id2sym[id] - - def get_piece_size(self) -> int: - return len(self._sym2id) - - def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: - return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] - - def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: - return [ - [i if i in self._sym2id else self.oov for i in self.text2word(text)] - for text in input - ] - - def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: - return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] - - def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: - return [self.sep.join(text) for text in input] - - def SplitBatch(self, input: List[str]) -> List[List[str]]: - return [self.text2word(text) for text in input] - - -def test_CharTokenizer(): - test_single_string = "こんにちは" - test_multiple_string = [ - "今日はいい天気ですよね", - "諏訪湖は綺麗でしょう", - "这在词表外", - "分かち 書き に し た 文章 です", - "", - ] - test_empty_string = "" - sp = Tokenizer.load(Path("lang_char"), "char", oov="") - splitter = sp.split - print(sp.encode(test_single_string, out_type=str)) - print(sp.encode(test_single_string, out_type=int)) - print(sp.encode(test_multiple_string, out_type=str)) - print(sp.encode(test_multiple_string, out_type=int)) - print(sp.encode(test_empty_string, out_type=str)) - print(sp.encode(test_empty_string, out_type=int)) - print(sp.decode(sp.encode(test_single_string, out_type=str))) - print(sp.decode(sp.encode(test_single_string, out_type=int))) - print(sp.decode(sp.encode(test_multiple_string, out_type=str))) - print(sp.decode(sp.encode(test_multiple_string, out_type=int))) - print(sp.decode(sp.encode(test_empty_string, out_type=str))) - print(sp.decode(sp.encode(test_empty_string, out_type=int))) - print(splitter(test_single_string)) - print(splitter(test_multiple_string)) - print(splitter(test_empty_string)) - - -if __name__ == "__main__": - test_CharTokenizer() diff --git a/egs/reazonspeech/ASR/local/validate_manifest.py b/egs/reazonspeech/ASR/local/validate_manifest.py deleted file mode 100644 index 7f67c64b6..000000000 --- a/egs/reazonspeech/ASR/local/validate_manifest.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. 
- -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - s = c.supervisions[0] - - # Removed because when the cuts were trimmed from supervisions, - # the start time of the supervision can be lesser than cut start time. - # https://github.com/lhotse-speech/lhotse/issues/813 - # if s.start < c.start: - # raise ValueError( - # f"{c.id}: Supervision start time {s.start} is less " - # f"than cut start time {c.start}" - # ) - - if s.end > c.end: - raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" - ) - - -def main(): - args = get_args() - - manifest = Path(args.manifest) - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest(manifest) - assert isinstance(cut_set, CutSet) - - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/reazonspeech/ASR/prepare.sh b/egs/reazonspeech/ASR/prepare.sh deleted file mode 100755 index d5e0a9491..000000000 --- a/egs/reazonspeech/ASR/prepare.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/ReazonSpeech -# You can find FLAC files in this directory. -# You can download them from https://huggingface.co/datasets/reazon-research/reazonspeech -# -# - $dl_dir/dataset.json -# The metadata of the ReazonSpeech dataset. - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "Running prepare.sh" - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/ReazonSpeech, - # you can create a symlink - # - # ln -sfv /path/to/ReazonSpeech $dl_dir/ReazonSpeech - # - if [ ! -d $dl_dir/ReazonSpeech/downloads ]; then - # Download small-v1 by default. - lhotse download reazonspeech --subset small-v1 $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare ReazonSpeech manifest" - # We assume that you have downloaded the ReazonSpeech corpus - # to $dl_dir/ReazonSpeech - mkdir -p data/manifests - if [ ! 
-e data/manifests/.reazonspeech.done ]; then - lhotse prepare reazonspeech -j $nj $dl_dir/ReazonSpeech data/manifests - touch data/manifests/.reazonspeech.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute ReazonSpeech fbank" - if [ ! -e data/manifests/.reazonspeech-validated.done ]; then - python local/compute_fbank_reazonspeech.py --manifest-dir data/manifests - python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_train.jsonl.gz - python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_dev.jsonl.gz - python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_test.jsonl.gz - touch data/manifests/.reazonspeech-validated.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare ReazonSpeech lang_char" - python local/prepare_lang_char.py data/manifests/reazonspeech_cuts_train.jsonl.gz -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt - cat data/manifests/manifest_statistics.txt -fi \ No newline at end of file diff --git a/egs/reazonspeech/ASR/shared b/egs/reazonspeech/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/reazonspeech/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/asr_datamodule.py b/egs/reazonspeech/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a48591198..000000000 --- a/egs/reazonspeech/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/beam_search.py b/egs/reazonspeech/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/reazonspeech/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/ctc_decode.py b/egs/reazonspeech/ASR/zipformer/ctc_decode.py deleted file mode 120000 index faa8bd562..000000000 --- a/egs/reazonspeech/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/decode.py b/egs/reazonspeech/ASR/zipformer/decode.py deleted file mode 100755 index cdd2145f2..000000000 --- a/egs/reazonspeech/ASR/zipformer/decode.py +++ /dev/null @@ -1,1076 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
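After the fbank and validation stages of prepare.sh above have run, a quick hedged sanity check is to read one cut back and load its precomputed 80-bin fbank matrix (paths follow the stage 2 defaults; this assumes the features were actually computed and the dev split is used only as an example):

```python
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/manifests/reazonspeech_cuts_dev.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # (num_frames, 80) numpy array
print(cut.id, feats.shape, cut.supervisions[0].text)
```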
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from tokenizer import Tokenizer -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
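The `--epoch`/`--avg` options above follow the usual icefall convention: with `--use-averaged-model` disabled, the `avg` consecutive checkpoints ending at `--epoch` are averaged. A small illustration of which files that selects (the experiment dir is illustrative; the actual weight averaging is handled by `icefall.checkpoint`):

```python
epoch, avg = 39, 7  # the values quoted in the ReazonSpeech RESULTS.md above

# Consecutive checkpoints ending at --epoch, per the help text above
# ("consecutive checkpoints before the checkpoint specified by '--epoch'").
filenames = [f"zipformer/exp/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
print(filenames)  # epoch-33.pt ... epoch-39.pt
```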
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = 
modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.text2word(sp.decode(hyp))) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
- Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains three elements: - the cut ID, the reference transcript, and the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # and are defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - 
model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append((sp.encode(line.strip()), 0.0)) - context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - - for subdir in ["valid"]: - results_dict = decode_dataset( - dl=reazonspeech_corpus.test_dataloaders( - getattr(reazonspeech_corpus, f"{subdir}_cuts")() - ), - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - tot_err = save_results( - params=params, - test_set_name=subdir, - results_dict=results_dict, - ) - # with ( - # params.res_dir - # / ( - # f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" - # f"_{params.avg}_{params.epoch}.cer" - # ) - # ).open("w") as fout: - # if len(tot_err) == 1: - # fout.write(f"{tot_err[0][1]}") - # else: - # fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/zipformer/decode_stream.py b/egs/reazonspeech/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/reazonspeech/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/decoder.py b/egs/reazonspeech/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/reazonspeech/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py deleted file mode 100755 index 072679cfc..000000000 --- a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py +++ /dev/null @@ -1,1261 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] 
= None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. 
- train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - logging.debug( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. 
" - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/zipformer/encoder_interface.py b/egs/reazonspeech/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/reazonspeech/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/export-onnx.py b/egs/reazonspeech/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/reazonspeech/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/export.py b/egs/reazonspeech/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/reazonspeech/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py b/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py deleted file mode 120000 index 5a015ee6c..000000000 --- a/egs/reazonspeech/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/joiner.py b/egs/reazonspeech/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/reazonspeech/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/model.py b/egs/reazonspeech/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/reazonspeech/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/my_profile.py b/egs/reazonspeech/ASR/zipformer/my_profile.py deleted file mode 120000 index 3a90b2628..000000000 --- a/egs/reazonspeech/ASR/zipformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/my_profile.py \ No newline at 
end of file diff --git a/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py b/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/reazonspeech/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/optim.py b/egs/reazonspeech/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/reazonspeech/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/pretrained.py b/egs/reazonspeech/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/reazonspeech/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/scaling.py b/egs/reazonspeech/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/reazonspeech/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/scaling_converter.py b/egs/reazonspeech/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/reazonspeech/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py b/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 7e3199e09..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,900 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Usage: -./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192 - -""" - -import argparse -import logging -import math -import os -import pdb -import subprocess as sp -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -from asr_datamodule import ReazonSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. 
Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded in parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single zipformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`.
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - # pdb.set_trace() - # print(model) - # print(model.device) - # device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=model.device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=model.device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - # if decode_streams[i].done: - # finished_streams.append(i) - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - tokenizer: Tokenizer, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - tokenizer: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - tokenizer.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - # print("INSIDE LEN DECODE STREAMS") - # pdb.set_trace() - # print(model.device) - # test_device = model.device - # print("done") - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - # print('INSIDE FOR LOOP ') - # print(finished_streams) - - if not finished_streams: - print("No finished streams, breaking the loop") - break - - for i in sorted(finished_streams, reverse=True): - try: - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - tokenizer.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - except IndexError as e: - print(f"IndexError: {e}") - print(f"decode_streams length: {len(decode_streams)}") - print(f"finished_streams: {finished_streams}") - print(f"i: {i}") - continue - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - torch.cuda.synchronize() - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp_token = Tokenizer.load(params.lang, params.lang_type) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp_token.piece_to_id("") - params.unk_id = sp_token.piece_to_id("") - params.vocab_size = sp_token.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) 
- else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - - valid_cuts = reazonspeech_corpus.valid_cuts() - test_cuts = reazonspeech_corpus.test_cuts() - - test_sets = ["valid", "test"] - test_cuts = [valid_cuts, test_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - tokenizer=sp_token, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - # valid_cuts = reazonspeech_corpus.valid_cuts() - - # for valid_cut in valid_cuts: - # results_dict = decode_dataset( - # cuts=valid_cut, - # params=params, - # model=model, - # sp=sp, - # decoding_graph=decoding_graph, - # ) - # save_results( - # params=params, - # test_set_name="valid", - # results_dict=results_dict, - # ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/zipformer/subsampling.py b/egs/reazonspeech/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/reazonspeech/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/test_scaling.py b/egs/reazonspeech/ASR/zipformer/test_scaling.py deleted file mode 120000 index 715798436..000000000 --- a/egs/reazonspeech/ASR/zipformer/test_scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/test_subsampling.py b/egs/reazonspeech/ASR/zipformer/test_subsampling.py deleted file mode 120000 index bf0ee3d11..000000000 --- a/egs/reazonspeech/ASR/zipformer/test_subsampling.py +++ /dev/null @@ -1 +0,0 @@ 
-../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/tokenizer.py b/egs/reazonspeech/ASR/zipformer/tokenizer.py deleted file mode 120000 index 958c99e85..000000000 --- a/egs/reazonspeech/ASR/zipformer/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py deleted file mode 100755 index 30bd3efba..000000000 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ /dev/null @@ -1,1383 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = 
Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. 
" - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.015, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. 
- """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. 
- scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # <blk> is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 30 seconds - # - # Caution: There is a reason to select 30.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 30.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. 
" - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/zipformer/zipformer.py b/egs/reazonspeech/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/reazonspeech/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/README.md b/egs/speech_llm/ASR_LLM/README.md old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/assets/framework.png b/egs/speech_llm/ASR_LLM/assets/framework.png old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py old mode 100644 new mode 100755 diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt old mode 100644 new mode 100755 diff --git a/egs/speechio/ASR/README.md b/egs/speechio/ASR/README.md deleted file mode 100644 index 2675efd9b..000000000 --- a/egs/speechio/ASR/README.md +++ /dev/null @@ -1,15 +0,0 @@ - -# Introduction - -This recipe includes some different pretrained ASR models' decoding results with [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets. - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Pretrained Models - -The following table lists the pretrained models. 
- -| Model | Huggingface | Comment | -|---------------------------------------|--------------------|-----------------------------| -| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Trained with the [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/) | -| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Trained with the [wenetspeech recipe](../../wenetspeech/ASR/whisper/) | diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md deleted file mode 100644 index 3c556f74e..000000000 --- a/egs/speechio/ASR/RESULTS.md +++ /dev/null @@ -1,115 +0,0 @@ -## Results - -### SpeechIO Test Set Decoding Results - - - - -#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026) -| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 | -| --- | --- | --- | --- | -| 1 | ximalaya_api_zh | 1.72% | 2023.12 | -| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 | -| 3 | microsoft_batch_zh | 2.40% | 2023.12 | -| 4 | bilibili_api_zh | 2.90% | 2023.09 | -| 5 | tencent_api_zh | 3.18% | 2023.12 | -| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 | -| 7 | aispeech_api_zh | 3.62% | 2023.12 | -| 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 | -| 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 | -| 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 | -| 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 | -| 12 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | -| 13 | baidu_pro_api_zh | 7.29% | 2023.12 | - -Note: The API results above are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results use the default [normalize method](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67). - -**whisper-large-ft-v1-distill** was not trained with an actual distillation loss; instead, it adopts the model structure and parameter initialization method of the [distill-whisper](https://arxiv.org/abs/2311.00430) paper: only the first and the last decoder layers are retained (see the sketch after the model list below). - -
Detail all models

- -| Model | Training Set | Note | -|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------| -|[zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24)| multi-hans-zh | decoding with transducer head and blank penalty 2.0 | -|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs| -|[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs | -|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs| -|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs| -
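A rough sketch of that layer-pruning initialization, assuming the `openai-whisper` package (its `TextDecoder` keeps its layers in `decoder.blocks`); this is an illustration only, not the recipe's actual training code:

```python
# Hypothetical sketch of the distill-whisper-style initialization described
# above; the fine-tuning code that produced the results in this table is not shown.
from dataclasses import replace

import torch.nn as nn
import whisper  # assumption: the openai-whisper package is installed

model = whisper.load_model("large-v2")
blocks = model.decoder.blocks  # nn.ModuleList holding all text-decoder layers
# Keep only the first and the last decoder layer, as in distill-whisper.
model.decoder.blocks = nn.ModuleList([blocks[0], blocks[-1]])
model.dims = replace(model.dims, n_text_layer=2)  # keep the metadata consistent
# The truncated model is then fine-tuned with the ordinary ASR loss
# (no distillation loss) on the data listed in the table above.
```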

- - -
Detail all results (字错误率 CER %)

- -| Test Set ID | 测试场景&内容领域 | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer | -|----------------------|-------------------------------|-----------------|---------|-----------|-----------| -| Avg (01-26) | | 2.9 | 6.34 | 4.32 | 6.17 | -| SPEECHIO_ASR_ZH00001 | 新闻联播 | 0.54 | 1.42 | 1.09 | 1.37 | -| SPEECHIO_ASR_ZH00002 | 访谈 鲁豫有约 | 2.78 | 4.76 | 3.21 | 4.67 | -| SPEECHIO_ASR_ZH00003 | 电视节目 天下足球 | 0.81 | 2.17 | 1.70 | 2.71 | -| SPEECHIO_ASR_ZH00004 | 场馆演讲 罗振宇跨年 | 1.48 | 2.53 | 1.86 | 2.54 | -| SPEECHIO_ASR_ZH00005 | 在线教育 李永乐 科普 | 1.47 | 4.27 | 1.95 | 3.12 | -| SPEECHIO_ASR_ZH00006 | 直播 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.46 | 12.86 | -| SPEECHIO_ASR_ZH00007 | 直播 带货 李佳琪&薇娅 | 6.19 | 13.38 | 10.38 | 14.58 | -| SPEECHIO_ASR_ZH00008 | 线下培训 老罗语录 | 3.68 | 9.56 | 6.9 | 9.05 | -| SPEECHIO_ASR_ZH00009 | 播客 故事FM | 3.18 | 5.66 | 3.78 | 5.4 | -| SPEECHIO_ASR_ZH00010 | 播客 创业内幕 | 3.51 | 7.84 | 4.36 | 6.4 | -| SPEECHIO_ASR_ZH00011 | 在线教育 罗翔 刑法法考 | 1.77 | 3.22 | 2.40 | 3.12 | -| SPEECHIO_ASR_ZH00012 | 在线教育 张雪峰 考研 | 2.11 | 5.98 | 3.03 | 4.41 | -| SPEECHIO_ASR_ZH00013 | 短视频 影剪 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.72 | 6.56 | -| SPEECHIO_ASR_ZH00014 | 短视频 美式&烹饪 | 3.56 | 6.03 | 4.92 | 8.14 | -| SPEECHIO_ASR_ZH00015 | 评书 单田芳 白眉大侠 | 4.72 | 8.77 | 7.92 | 9.1 | -| SPEECHIO_ASR_ZH00016 | 相声 德云社专场 | 3.01 | 5.24 | 4.15 | 5.59 | -| SPEECHIO_ASR_ZH00017 | 脱口秀 吐槽大会 | 2.93 | 7.05 | 3.04 | 5.17 | -| SPEECHIO_ASR_ZH00018 | 少儿卡通 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.27 | 4.15 | -| SPEECHIO_ASR_ZH00019 | 体育赛事解说 NBA比赛 | 2.32 | 6.89 | 4.39 | 6.66 | -| SPEECHIO_ASR_ZH00020 | 纪录片 篮球人物 | 1.51 | 4.16 | 3.04 | 4.2 | -| SPEECHIO_ASR_ZH00021 | 短视频 汽车之家 汽车评测 | 1.75 | 4.77 | 2.69 | 4.17 | -| SPEECHIO_ASR_ZH00022 | 短视频 小艾大叔 豪宅带看 | 3.29 | 6.35 | 5.44 | 6.72 | -| SPEECHIO_ASR_ZH00023 | 短视频 开箱视频 Zeal&无聊开箱 | 2.18 | 8.99 | 4.08 | 7.94 | -| SPEECHIO_ASR_ZH00024 | 短视频 付老师 农业种植 | 4.80 | 10.81 | 6.06 | 8.64 | -| SPEECHIO_ASR_ZH00025 | 线下课堂 石国鹏 古希腊哲学 | 3.32 | 8.41 | 5.39 | 8.54 | -| SPEECHIO_ASR_ZH00026 | 广播电台节目 张震鬼故事 | 3.70 | 4.52 | 4.06 | 4.67 | -
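For reference, the CER (字错误率) in these tables is the character-level edit distance between the normalized reference and the hypothesis, divided by the number of reference characters. The snippet below is only a self-contained illustration of that definition; the reported numbers come from the leaderboard's own normalization and scoring scripts.

```python
# Minimal illustration of the CER metric used in the tables above; the actual
# results were produced by the SpeechIO leaderboard's scoring pipeline, not this code.
def cer(ref: str, hyp: str) -> float:
    """Character error rate: edit distance / number of reference characters."""
    prev_row = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur_row = [i]
        for j, h in enumerate(hyp, 1):
            cur_row.append(min(
                prev_row[j] + 1,             # deletion (reference char missed)
                cur_row[j - 1] + 1,          # insertion (extra hypothesis char)
                prev_row[j - 1] + (r != h),  # substitution or match
            ))
        prev_row = cur_row
    return prev_row[-1] / max(len(ref), 1)


print(f"{cer('今天天气很好', '今天天汽很好') * 100:.2f}%")  # one substitution in six characters -> 16.67%
```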

- - -Command for decoding using fine-tuned whisper: -```bash -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper -ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --start-index 0 --end-index 26 \ - --remove-whisper-encoder-input-length-restriction True \ - --manifest-dir data/fbank \ - --beam-size 1 --max-duration 50 -``` -Command for decoding using pretrained zipformer: -```bash -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 -cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "data/lang_bpe_2000/*" -ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt -ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data -wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt -mv words.txt ./data/lang_bpe_2000/ - -./zipformer/decode.py \ - --epoch 999 \ - --avg 1 \ - --blank-penalty 2.0 \ - --use-averaged-model false \ - --exp-dir ./zipformer/exp_pretrain \ - --max-duration 600 \ - --start-index 0 --end-index 26 \ - --manifest-dir data/fbank_kaldi \ - --decoding-method greedy_search -``` - -SpeechIO fbank features, decoding scripts, logs, and decoding results -are available at [part1]() and [part2](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1). diff --git a/egs/speechio/ASR/local/compute_fbank_speechio.py b/egs/speechio/ASR/local/compute_fbank_speechio.py deleted file mode 100644 index 5b3489a9f..000000000 --- a/egs/speechio/ASR/local/compute_fbank_speechio.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the SpeechIO test sets. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slows things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. - - -def compute_fbank_speechio( - num_mel_bins: int = 80, - speed_perturb: bool = False, - fbank_dir: str = "data/fbank", - whisper_fbank: bool = False, -): - src_dir = Path("data/manifests") - output_dir = Path(fbank_dir) - num_jobs = min(8, os.cpu_count()) - - dataset_parts = [] - for i in range(SPEECHIO_TESTSET_INDEX + 1): - idx = f"{i}".zfill(2) - dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") - - prefix = "speechio" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--fbank-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_speechio( - num_mel_bins=args.num_mel_bins, - fbank_dir=args.fbank_dir, - whisper_fbank=args.whisper_fbank, - ) diff --git a/egs/speechio/ASR/local/display_manifest_statistics.py b/egs/speechio/ASR/local/display_manifest_statistics.py deleted file mode 100644 index 0c803bfcd..000000000 --- a/egs/speechio/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,1162 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer_stateless/train.py -for usage. -""" - -SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. - -from lhotse import load_manifest_lazy - - -def main(): - dataset_parts = [] - for i in range(SPEECHIO_TESTSET_INDEX + 1): - idx = f"{i}".zfill(2) - dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") - - prefix = "speechio" - suffix = "jsonl.gz" - - for partition in dataset_parts: - path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}" - cuts = load_manifest_lazy(path) - print( - f"===================Duration statistics of {partition}===================" - ) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -===================Duration statistics of SPEECHIO_ASR_ZH00000=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 879 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:36:09 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.6 │ -├───────────────────────────┼──────────┤ -│ std │ 2.0 │ -├───────────────────────────┼──────────┤ -│ min │ 1.7 │ -├───────────────────────────┼──────────┤ -│ 25% │ 5.0 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.5 │ -├───────────────────────────┼──────────┤ -│ 75% │ 8.1 │ -├───────────────────────────┼──────────┤ -│ 99% │ 11.2 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 11.6 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.2 │ -├───────────────────────────┼──────────┤ -│ max │ 12.5 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 879 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 879 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 879 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:36:09 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:36:09 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00001=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 5069 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 08:43:04 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.2 │ -├───────────────────────────┼──────────┤ -│ std │ 2.1 │ -├───────────────────────────┼──────────┤ -│ min │ 0.6 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.6 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.2 │ -├───────────────────────────┼──────────┤ -│ 75% │ 7.9 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.7 │ -├───────────────────────────┼──────────┤ -│ max │ 12.5 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 5069 │ -├───────────────────────────┼──────────┤ 
-│ Features available: │ 5069 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 5069 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 08:43:04 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 08:43:04 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00002=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 2993 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:45:09 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.3 │ -├───────────────────────────┼──────────┤ -│ std │ 1.5 │ -├───────────────────────────┼──────────┤ -│ min │ 0.4 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.2 │ -├───────────────────────────┼──────────┤ -│ 50% │ 3.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 4.3 │ -├───────────────────────────┼──────────┤ -│ 99% │ 7.3 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 7.8 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 9.1 │ -├───────────────────────────┼──────────┤ -│ max │ 11.8 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 2993 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 2993 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 2993 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:45:09 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:45:09 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00003=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1683 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:23:28 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.1 │ -├───────────────────────────┼──────────┤ -│ std │ 1.4 │ -├───────────────────────────┼──────────┤ -│ min │ 2.4 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.0 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.0 │ -├───────────────────────────┼──────────┤ -│ 99% │ 9.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.4 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.8 │ -├───────────────────────────┼──────────┤ -│ max │ 14.2 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1683 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1683 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1683 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:23:28 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ 
Total speaking time duration │ 02:23:28 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00004=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1311 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:49:16 │ -├───────────────────────────┼──────────┤ -│ mean │ 7.7 │ -├───────────────────────────┼──────────┤ -│ std │ 2.8 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 5.8 │ -├───────────────────────────┼──────────┤ -│ 50% │ 8.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 9.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 12.9 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 13.5 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 13.8 │ -├───────────────────────────┼──────────┤ -│ max │ 14.4 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1311 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1311 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1311 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:49:16 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:49:16 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00005=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 3148 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 04:22:47 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.0 │ -├───────────────────────────┼──────────┤ -│ std │ 1.4 │ -├───────────────────────────┼──────────┤ -│ min │ 2.0 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.9 │ -├───────────────────────────┼──────────┤ -│ 99% │ 8.8 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.3 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.3 │ -├───────────────────────────┼──────────┤ -│ max │ 11.1 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 3148 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 3148 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 3148 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 04:22:47 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 04:22:47 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00006=================== -Cut statistics: 
-╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1561 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:39:33 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.8 │ -├───────────────────────────┼──────────┤ -│ std │ 2.2 │ -├───────────────────────────┼──────────┤ -│ min │ 0.4 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.2 │ -├───────────────────────────┼──────────┤ -│ 50% │ 3.3 │ -├───────────────────────────┼──────────┤ -│ 75% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.4 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 11.3 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 15.3 │ -├───────────────────────────┼──────────┤ -│ max │ 23.8 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1561 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1561 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1561 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:39:33 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:39:33 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00007=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 770 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 00:58:57 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.6 │ -├───────────────────────────┼──────────┤ -│ std │ 2.4 │ -├───────────────────────────┼──────────┤ -│ min │ 0.7 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.7 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.0 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.0 │ -├───────────────────────────┼──────────┤ -│ 99% │ 11.8 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 13.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 15.1 │ -├───────────────────────────┼──────────┤ -│ max │ 18.7 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 770 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 770 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 770 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 00:58:57 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 00:58:57 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00008=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 884 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:16:55 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.2 │ -├───────────────────────────┼──────────┤ -│ std │ 2.3 │ -├───────────────────────────┼──────────┤ -│ min │ 1.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.5 │ 
-├───────────────────────────┼──────────┤ -│ 50% │ 5.0 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.4 │ -├───────────────────────────┼──────────┤ -│ 99% │ 11.3 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 12.7 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 16.2 │ -├───────────────────────────┼──────────┤ -│ max │ 18.5 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 884 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 884 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 884 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:16:55 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:16:55 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00009=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 3466 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 04:38:13 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.8 │ -├───────────────────────────┼──────────┤ -│ std │ 1.9 │ -├───────────────────────────┼──────────┤ -│ min │ 1.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.4 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.5 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.9 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.5 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 11.3 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.5 │ -├───────────────────────────┼──────────┤ -│ max │ 13.1 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 3466 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 3466 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 3466 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 04:38:13 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 04:38:13 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00010=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 2251 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 04:12:54 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.7 │ -├───────────────────────────┼──────────┤ -│ std │ 3.0 │ -├───────────────────────────┼──────────┤ -│ min │ 1.4 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.5 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.3 │ -├───────────────────────────┼──────────┤ -│ 75% │ 8.5 │ -├───────────────────────────┼──────────┤ -│ 99% │ 14.9 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 15.5 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 15.8 │ -├───────────────────────────┼──────────┤ -│ max │ 16.2 │ 
-├───────────────────────────┼──────────┤ -│ Recordings available: │ 2251 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 2251 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 2251 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 04:12:54 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 04:12:54 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00011=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1053 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 03:27:12 │ -├───────────────────────────┼──────────┤ -│ mean │ 11.8 │ -├───────────────────────────┼──────────┤ -│ std │ 3.4 │ -├───────────────────────────┼──────────┤ -│ min │ 1.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 11.5 │ -├───────────────────────────┼──────────┤ -│ 50% │ 13.0 │ -├───────────────────────────┼──────────┤ -│ 75% │ 13.9 │ -├───────────────────────────┼──────────┤ -│ 99% │ 15.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 15.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 20.7 │ -├───────────────────────────┼──────────┤ -│ max │ 22.2 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1053 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1053 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1053 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 03:27:12 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 03:27:12 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00012=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1170 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 03:23:34 │ -├───────────────────────────┼──────────┤ -│ mean │ 10.4 │ -├───────────────────────────┼──────────┤ -│ std │ 3.5 │ -├───────────────────────────┼──────────┤ -│ min │ 0.8 │ -├───────────────────────────┼──────────┤ -│ 25% │ 8.0 │ -├───────────────────────────┼──────────┤ -│ 50% │ 11.5 │ -├───────────────────────────┼──────────┤ -│ 75% │ 13.2 │ -├───────────────────────────┼──────────┤ -│ 99% │ 15.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 15.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 15.7 │ -├───────────────────────────┼──────────┤ -│ max │ 20.3 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1170 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1170 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1170 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ 
Total speech duration │ 03:23:34 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 03:23:34 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00013=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1321 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:46:41 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.8 │ -├───────────────────────────┼──────────┤ -│ std │ 1.5 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.8 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.8 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 8.5 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 9.5 │ -├───────────────────────────┼──────────┤ -│ max │ 9.7 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1321 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1321 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1321 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:46:41 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:46:41 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00014=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 856 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:00:39 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.3 │ -├───────────────────────────┼──────────┤ -│ std │ 1.8 │ -├───────────────────────────┼──────────┤ -│ min │ 0.8 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.5 │ -├───────────────────────────┼──────────┤ -│ 99% │ 8.5 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.2 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 11.1 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 856 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 856 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 856 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:00:39 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:00:39 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ 
-╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00015=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1168 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:08:52 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.6 │ -├───────────────────────────┼──────────┤ -│ std │ 2.0 │ -├───────────────────────────┼──────────┤ -│ min │ 1.2 │ -├───────────────────────────┼──────────┤ -│ 25% │ 5.3 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.8 │ -├───────────────────────────┼──────────┤ -│ 75% │ 8.2 │ -├───────────────────────────┼──────────┤ -│ 99% │ 9.9 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.1 │ -├───────────────────────────┼──────────┤ -│ max │ 15.5 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1168 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1168 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1168 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:08:52 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:08:52 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00016=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1201 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:00:46 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.0 │ -├───────────────────────────┼──────────┤ -│ std │ 2.0 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 1.6 │ -├───────────────────────────┼──────────┤ -│ 50% │ 2.3 │ -├───────────────────────────┼──────────┤ -│ 75% │ 3.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 9.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.5 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 9.7 │ -├───────────────────────────┼──────────┤ -│ max │ 9.9 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1201 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1201 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1201 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:00:46 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:00:46 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00017=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1271 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:47:57 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.1 │ 
-├───────────────────────────┼──────────┤ -│ std │ 2.2 │ -├───────────────────────────┼──────────┤ -│ min │ 1.0 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.3 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 9.7 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 10.0 │ -├───────────────────────────┼──────────┤ -│ max │ 10.4 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1271 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1271 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1271 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:47:57 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:47:57 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00018=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 899 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 00:51:12 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.4 │ -├───────────────────────────┼──────────┤ -│ std │ 1.2 │ -├───────────────────────────┼──────────┤ -│ min │ 1.3 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.5 │ -├───────────────────────────┼──────────┤ -│ 50% │ 3.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 4.1 │ -├───────────────────────────┼──────────┤ -│ 99% │ 6.7 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 7.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 8.2 │ -├───────────────────────────┼──────────┤ -│ max │ 9.2 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 899 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 899 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 899 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 00:51:12 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 00:51:12 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00019=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 615 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 00:41:43 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.1 │ -├───────────────────────────┼──────────┤ -│ std │ 1.5 │ -├───────────────────────────┼──────────┤ -│ min │ 1.3 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.8 │ -├───────────────────────────┼──────────┤ -│ 50% │ 3.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.2 │ -├───────────────────────────┼──────────┤ -│ 99% │ 7.9 │ -├───────────────────────────┼──────────┤ -│ 
99.5% │ 8.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 8.6 │ -├───────────────────────────┼──────────┤ -│ max │ 8.8 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 615 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 615 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 615 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 00:41:43 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 00:41:43 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00020=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1590 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:10:54 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.9 │ -├───────────────────────────┼──────────┤ -│ std │ 1.5 │ -├───────────────────────────┼──────────┤ -│ min │ 1.2 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.8 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.0 │ -├───────────────────────────┼──────────┤ -│ 99% │ 8.5 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 8.7 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 9.2 │ -├───────────────────────────┼──────────┤ -│ max │ 10.4 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1590 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1590 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1590 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:10:54 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:10:54 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00021=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1035 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:44:07 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.0 │ -├───────────────────────────┼──────────┤ -│ std │ 1.8 │ -├───────────────────────────┼──────────┤ -│ min │ 1.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.7 │ -├───────────────────────────┼──────────┤ -│ 50% │ 5.9 │ -├───────────────────────────┼──────────┤ -│ 75% │ 7.3 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.4 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.6 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 11.0 │ -├───────────────────────────┼──────────┤ -│ max │ 11.1 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1035 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1035 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1035 │ 
-╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:44:07 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:44:07 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00022=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1026 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:40:43 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.9 │ -├───────────────────────────┼──────────┤ -│ std │ 2.2 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.4 │ -├───────────────────────────┼──────────┤ -│ 50% │ 5.8 │ -├───────────────────────────┼──────────┤ -│ 75% │ 7.1 │ -├───────────────────────────┼──────────┤ -│ 99% │ 12.1 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 12.7 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 13.9 │ -├───────────────────────────┼──────────┤ -│ max │ 14.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1026 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1026 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1026 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:40:43 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:40:43 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00023=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1528 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:06:51 │ -├───────────────────────────┼──────────┤ -│ mean │ 5.0 │ -├───────────────────────────┼──────────┤ -│ std │ 2.5 │ -├───────────────────────────┼──────────┤ -│ min │ 0.5 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.1 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.5 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.6 │ -├───────────────────────────┼──────────┤ -│ 99% │ 12.3 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 13.9 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 15.8 │ -├───────────────────────────┼──────────┤ -│ max │ 16.8 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1528 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1528 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1528 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:06:51 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:06:51 │ 100.00% of recording │ 
-├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00024=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1930 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:39:02 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.9 │ -├───────────────────────────┼──────────┤ -│ std │ 2.0 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 3.4 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.7 │ -├───────────────────────────┼──────────┤ -│ 75% │ 6.2 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.3 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.9 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.0 │ -├───────────────────────────┼──────────┤ -│ max │ 12.6 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1930 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1930 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1930 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:39:02 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:39:02 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00025=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1164 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:24:42 │ -├───────────────────────────┼──────────┤ -│ mean │ 4.4 │ -├───────────────────────────┼──────────┤ -│ std │ 1.9 │ -├───────────────────────────┼──────────┤ -│ min │ 0.9 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 4.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 5.6 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.4 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 10.9 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.5 │ -├───────────────────────────┼──────────┤ -│ max │ 13.0 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1164 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1164 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1164 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:24:42 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:24:42 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -===================Duration statistics of SPEECHIO_ASR_ZH00026=================== -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 1336 │ 
-├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 02:25:38 │ -├───────────────────────────┼──────────┤ -│ mean │ 6.5 │ -├───────────────────────────┼──────────┤ -│ std │ 2.3 │ -├───────────────────────────┼──────────┤ -│ min │ 0.5 │ -├───────────────────────────┼──────────┤ -│ 25% │ 4.9 │ -├───────────────────────────┼──────────┤ -│ 50% │ 6.8 │ -├───────────────────────────┼──────────┤ -│ 75% │ 8.3 │ -├───────────────────────────┼──────────┤ -│ 99% │ 10.4 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 11.9 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.9 │ -├───────────────────────────┼──────────┤ -│ max │ 13.3 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 1336 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 1336 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 1336 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 02:25:38 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 02:25:38 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ - -""" diff --git a/egs/speechio/ASR/local/normalize_results.py b/egs/speechio/ASR/local/normalize_results.py deleted file mode 100755 index 02277e2a8..000000000 --- a/egs/speechio/ASR/local/normalize_results.py +++ /dev/null @@ -1,165 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2024 Author: Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file uses speech io offcial pipline to normalize the decoding results. -https://github.com/SpeechColab/Leaderboard/blob/master/utils/textnorm_zh.py - -Usage: - python normalize_results.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import kaldialign -from speechio_norm import TextNorm - -from icefall.utils import store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--model-log-dir", - type=str, - default="./recogs_whisper", - help="The directory to store the whisper logs: e.g. 
recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt", - ) - parser.add_argument( - "--output-log-dir", - type=str, - default="./results_whisper_norm", - help="The directory to store the normalized whisper logs", - ) - return parser - - -def save_results_with_speechio_text_norm( - res_dir: Path, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - normalizer = TextNorm() - # normlize items in results_dict - for key, results in results_dict.items(): - results_norm = [] - for item in results: - wav_name, ref, hyp = item - ref = normalizer(ref) - hyp = normalizer(hyp) - results_norm.append((wav_name, ref, hyp)) - results_dict[key] = results_norm - - test_set_wers = dict() - - suffix = "epoch-999-avg-1" - - for key, results in results_dict.items(): - recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - print(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - print("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - print(s) - - -def extract_hyp_ref_wavname(filename): - """ - 0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升'] - 0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升 - """ - hyps, refs, wav_name = [], [], [] - with open(filename, "r") as f: - for line in f: - if "ref" in line: - ref = line.split("ref=")[1].strip() - if ref[0] == "[": - ref = ref[2:-2] - list_elements = ref.split("', '") - ref = "".join(list_elements) - refs.append(ref) - elif "hyp" in line: - hyp = line.split("hyp=")[1].strip() - hyps.append(hyp) - wav_name.append(line.split(":")[0]) - return hyps, refs, wav_name - - -def get_filenames( - whisper_log_dir, - whisper_suffix="beam-search-epoch-999-avg-1", -): - results = [] - start_index, end_index = 0, 26 - dataset_parts = [] - for i in range(start_index, end_index + 1): - idx = f"{i}".zfill(2) - dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") - for partition in dataset_parts: - whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt" - results.append(whisper_filename) - return results - - -def main(): - parser = get_parser() - args = parser.parse_args() - # mkdir output_log_dir - Path(args.output_log_dir).mkdir(parents=True, exist_ok=True) - filenames = get_filenames(args.model_log_dir) - for filename in filenames: - hyps, refs, wav_name = extract_hyp_ref_wavname(filename) - partition_name = filename.split("/")[-1].split("-")[1] - - save_results_with_speechio_text_norm( - Path(args.output_log_dir), - partition_name, - {"norm": list(zip(wav_name, refs, hyps))}, - ) - - print(f"Processed {partition_name}") - - -if __name__ == "__main__": - main() diff --git 
a/egs/speechio/ASR/local/speechio_norm.py b/egs/speechio/ASR/local/speechio_norm.py deleted file mode 100755 index 6f3cd55b0..000000000 --- a/egs/speechio/ASR/local/speechio_norm.py +++ /dev/null @@ -1,1364 +0,0 @@ -#!/usr/bin/env python3 -# coding=utf-8 - -# Authors: -# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) -# 2019.9 - 2022 Jiayu DU -# -# requirements: -# - python 3.X -# notes: python 2.X WILL fail or produce misleading results - -import argparse -import csv -import os -import re -import string -import sys - -# ================================================================================ # -# basic constant -# ================================================================================ # -CHINESE_DIGIS = "零一二三四五六七八九" -BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" -BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" -SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" -SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" -LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" -LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" -SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" -SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" - -ZERO_ALT = "〇" -ONE_ALT = "幺" -TWO_ALTS = ["两", "兩"] - -POSITIVE = ["正", "正"] -NEGATIVE = ["负", "負"] -POINT = ["点", "點"] -# PLUS = [u'加', u'加'] -# SIL = [u'杠', u'槓'] - -FILLER_CHARS = ["呃", "啊"] - -ER_WHITELIST = ( - "(儿女|儿子|儿孙|女儿|儿媳|妻儿|" - "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|" - "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|" - "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)" -) -ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) - -# 中文数字系统类型 -NUMBERING_TYPES = ["low", "mid", "high"] - -CURRENCY_NAMES = ( - "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" - "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" -) -CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" -COM_QUANTIFIERS = ( - "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" - "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" - "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" - "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" - "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" - "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)" -) - - -# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) -CN_PUNCS_STOP = "!?。。" -CN_PUNCS_NONSTOP = ( - ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-" -) -CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP - -PUNCS = CN_PUNCS + string.punctuation -PUNCS_TRANSFORM = str.maketrans(PUNCS, " " * len(PUNCS), "") # replace puncs with space - - -# https://zh.wikipedia.org/wiki/全行和半行 -QJ2BJ = { - " ": " ", - "!": "!", - """: '"', - "#": "#", - "$": "$", - "%": "%", - "&": "&", - "'": "'", - "(": "(", - ")": ")", - "*": "*", - "+": "+", - ",": ",", - "-": "-", - ".": ".", - "/": "/", - "0": "0", - "1": "1", - "2": "2", - "3": "3", - "4": "4", - "5": "5", - "6": "6", - "7": "7", - "8": "8", - "9": "9", - ":": ":", - ";": ";", - "<": "<", - "=": "=", - ">": ">", - "?": "?", - "@": "@", - "A": "A", - "B": "B", - "C": "C", - "D": "D", - "E": "E", - "F": "F", - "G": "G", - "H": "H", - "I": "I", - "J": "J", - "K": "K", - "L": "L", - "M": "M", - "N": "N", - "O": "O", - "P": "P", - "Q": "Q", - "R": "R", - "S": "S", - "T": "T", - "U": "U", - "V": "V", - "W": "W", - "X": "X", - "Y": "Y", - "Z": "Z", - "[": "[", - "\": "\\", - "]": "]", - "^": "^", - "_": "_", - "`": "`", - "a": "a", - "b": "b", - "c": "c", - "d": 
"d", - "e": "e", - "f": "f", - "g": "g", - "h": "h", - "i": "i", - "j": "j", - "k": "k", - "l": "l", - "m": "m", - "n": "n", - "o": "o", - "p": "p", - "q": "q", - "r": "r", - "s": "s", - "t": "t", - "u": "u", - "v": "v", - "w": "w", - "x": "x", - "y": "y", - "z": "z", - "{": "{", - "|": "|", - "}": "}", - "~": "~", -} -QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "") - - -# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: -# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total -CN_CHARS_COMMON = ( - "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举" - "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互" - "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从" - "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优" - "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚" - "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣" - "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯" - "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌" - "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚" - "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六" - "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况" - "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈" - "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐" - "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼" - "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹" - "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵" - "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔" - "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊" - "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆" - "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎" - "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌" - "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛" - "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴" - "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌" - "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡" - "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓" - "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢" - "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥" - "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩" - "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基" - "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填" - "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复" - "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖" - "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮" - "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈" - "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻" - "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱" - "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽" - "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾" - "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝" - "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山" - "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃" - "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧" - "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉" - "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡" - "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐" - "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷" - "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖" - "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循" - "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀" - "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓" - "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟" - "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭" - "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥" - "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我" - "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔" - "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡" - "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥" - "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫" - "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎" - "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭" - "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘" - "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢" - "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整" - "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗" - 
"旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝" - "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡" - "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜" - "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽" - "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅" - "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔" - "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩" - "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯" - "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘" - "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂" - "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱" - "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞" - "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正" - "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒" - "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮" - "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽" - "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽" - "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼" - "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿" - "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸" - "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅" - "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟" - "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆" - "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕" - "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶" - "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧" - "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵" - "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈" - "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯" - "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵" - "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍" - "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷" - "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎" - "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎" - "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊" - "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊" - "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙" - "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱" - "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯" - "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐" - "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒" - "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩" - "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙" - "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾" - "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦" - "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知" - "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮" - "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍" - "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷" - "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭" - "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣" - "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗" - "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立" - "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯" - "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅" - "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾" - "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮" - "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢" - "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁" - "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩" - "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕" - "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐" - "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲" - "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者" - "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋" - "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸" - "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳" - "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒" - "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻" - "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般" - "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊" - "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈" - "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆" - "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐" - "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛" - "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥" - "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著" - "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺" - "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷" - 
"蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸" - "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱" - "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆" - "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗" - "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃" - "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡" - "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒" - "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂" - "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉" - "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认" - "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词" - "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请" - "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡" - "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹" - "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼" - "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤" - "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑" - "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮" - "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇" - "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较" - "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄" - "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆" - "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒" - "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱" - "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴" - "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝" - "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭" - "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒" - "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼" - "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨" - "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐" - "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹" - "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣" - "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼" - "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶" - "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈" - "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳" - "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰" - "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵" - "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额" - "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰" - "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥" - "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑" - "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高" - "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾" - "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨" - "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓" - "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶" - "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟" - "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄" - "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷" - "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃" - "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡" - "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽" - "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯" - "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟" - "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟" - "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟" - "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓" -) -CN_CHARS_EXT = "吶诶屌囧飚屄" - -CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT -IN_CH_CHARS = {c: True for c in CN_CHARS} - -EN_CHARS = string.ascii_letters + string.digits -IN_EN_CHARS = {c: True for c in EN_CHARS} - -VALID_CHARS = CN_CHARS + EN_CHARS + " " -IN_VALID_CHARS = {c: True for c in VALID_CHARS} - -# ================================================================================ # -# basic class -# ================================================================================ # - - -class ChineseChar(object): - """ - 中文字符 - 每个字符对应简体和繁体, - e.g. 
简体 = '负', 繁体 = '負' - 转换时可转换为简体或繁体 - """ - - def __init__(self, simplified, traditional): - self.simplified = simplified - self.traditional = traditional - # self.__repr__ = self.__str__ - - def __str__(self): - return self.simplified or self.traditional or None - - def __repr__(self): - return self.__str__() - - -class ChineseNumberUnit(ChineseChar): - """ - 中文数字/数位字符 - 每个字符除繁简体外还有一个额外的大写字符 - e.g. '陆' 和 '陸' - """ - - def __init__(self, power, simplified, traditional, big_s, big_t): - super(ChineseNumberUnit, self).__init__(simplified, traditional) - self.power = power - self.big_s = big_s - self.big_t = big_t - - def __str__(self): - return "10^{}".format(self.power) - - @classmethod - def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): - - if small_unit: - return ChineseNumberUnit( - power=index + 1, - simplified=value[0], - traditional=value[1], - big_s=value[1], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[0]: - return ChineseNumberUnit( - power=index + 8, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[1]: - return ChineseNumberUnit( - power=(index + 2) * 4, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[2]: - return ChineseNumberUnit( - power=pow(2, index + 3), - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - else: - raise ValueError( - "Counting type should be in {0} ({1} provided).".format( - NUMBERING_TYPES, numbering_type - ) - ) - - -class ChineseNumberDigit(ChineseChar): - """ - 中文数字字符 - """ - - def __init__( - self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None - ): - super(ChineseNumberDigit, self).__init__(simplified, traditional) - self.value = value - self.big_s = big_s - self.big_t = big_t - self.alt_s = alt_s - self.alt_t = alt_t - - def __str__(self): - return str(self.value) - - @classmethod - def create(cls, i, v): - return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) - - -class ChineseMath(ChineseChar): - """ - 中文数位字符 - """ - - def __init__(self, simplified, traditional, symbol, expression=None): - super(ChineseMath, self).__init__(simplified, traditional) - self.symbol = symbol - self.expression = expression - self.big_s = simplified - self.big_t = traditional - - -CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath - - -class NumberSystem(object): - """ - 中文数字系统 - """ - - pass - - -class MathSymbol(object): - """ - 用于中文数字系统的数学符号 (繁/简体), e.g. - positive = ['正', '正'] - negative = ['负', '負'] - point = ['点', '點'] - """ - - def __init__(self, positive, negative, point): - self.positive = positive - self.negative = negative - self.point = point - - def __iter__(self): - for v in self.__dict__.values(): - yield v - - -# class OtherSymbol(object): -# """ -# 其他符号 -# """ -# -# def __init__(self, sil): -# self.sil = sil -# -# def __iter__(self): -# for v in self.__dict__.values(): -# yield v - - -# ================================================================================ # -# basic utils -# ================================================================================ # -def create_system(numbering_type=NUMBERING_TYPES[1]): - """ - 根据数字系统类型返回创建相应的数字系统,默认为 mid - NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 - low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. - mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. 
- high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. - 返回对应的数字系统 - """ - - # chinese number units of '亿' and larger - all_larger_units = zip( - LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - larger_units = [ - CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) - ] - # chinese number units of '十, 百, 千, 万' - all_smaller_units = zip( - SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - smaller_units = [ - CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) - ] - # digis - chinese_digis = zip( - CHINESE_DIGIS, - CHINESE_DIGIS, - BIG_CHINESE_DIGIS_SIMPLIFIED, - BIG_CHINESE_DIGIS_TRADITIONAL, - ) - digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] - digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT - digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT - digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] - - # symbols - positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) - negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) - point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) - # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) - system = NumberSystem() - system.units = smaller_units + larger_units - system.digits = digits - system.math = MathSymbol(positive_cn, negative_cn, point_cn) - # system.symbols = OtherSymbol(sil_cn) - return system - - -def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): - def get_symbol(char, system): - for u in system.units: - if char in [u.traditional, u.simplified, u.big_s, u.big_t]: - return u - for d in system.digits: - if char in [ - d.traditional, - d.simplified, - d.big_s, - d.big_t, - d.alt_s, - d.alt_t, - ]: - return d - for m in system.math: - if char in [m.traditional, m.simplified]: - return m - - def string2symbols(chinese_string, system): - int_string, dec_string = chinese_string, "" - for p in [system.math.point.simplified, system.math.point.traditional]: - if p in chinese_string: - int_string, dec_string = chinese_string.split(p) - break - return [get_symbol(c, system) for c in int_string], [ - get_symbol(c, system) for c in dec_string - ] - - def correct_symbols(integer_symbols, system): - """ - 一百八 to 一百八十 - 一亿一千三百万 to 一亿 一千万 三百万 - """ - - if integer_symbols and isinstance(integer_symbols[0], CNU): - if integer_symbols[0].power == 1: - integer_symbols = [system.digits[1]] + integer_symbols - - if len(integer_symbols) > 1: - if isinstance(integer_symbols[-1], CND) and isinstance( - integer_symbols[-2], CNU - ): - integer_symbols.append( - CNU(integer_symbols[-2].power - 1, None, None, None, None) - ) - - result = [] - unit_count = 0 - for s in integer_symbols: - if isinstance(s, CND): - result.append(s) - unit_count = 0 - elif isinstance(s, CNU): - current_unit = CNU(s.power, None, None, None, None) - unit_count += 1 - - if unit_count == 1: - result.append(current_unit) - elif unit_count > 1: - for i in range(len(result)): - if ( - isinstance(result[-i - 1], CNU) - and result[-i - 1].power < current_unit.power - ): - result[-i - 1] = CNU( - result[-i - 1].power + current_unit.power, - None, - None, - None, - None, - ) - return result - - def compute_value(integer_symbols): - """ - Compute the value. - When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. - e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 - """ - value = [0] - last_power = 0 - for s in integer_symbols: - if isinstance(s, CND): - value[-1] = s.value - elif isinstance(s, CNU): - value[-1] *= pow(10, s.power) - if s.power > last_power: - value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) - last_power = s.power - value.append(0) - return sum(value) - - system = create_system(numbering_type) - int_part, dec_part = string2symbols(chinese_string, system) - int_part = correct_symbols(int_part, system) - int_str = str(compute_value(int_part)) - dec_str = "".join([str(d.value) for d in dec_part]) - if dec_part: - return "{0}.{1}".format(int_str, dec_str) - else: - return int_str - - -def num2chn( - number_string, - numbering_type=NUMBERING_TYPES[1], - big=False, - traditional=False, - alt_zero=False, - alt_one=False, - alt_two=True, - use_zeros=True, - use_units=True, -): - def get_value(value_string, use_zeros=True): - - striped_string = value_string.lstrip("0") - - # record nothing if all zeros - if not striped_string: - return [] - - # record one digits - elif len(striped_string) == 1: - if use_zeros and len(value_string) != len(striped_string): - return [system.digits[0], system.digits[int(striped_string)]] - else: - return [system.digits[int(striped_string)]] - - # recursively record multiple digits - else: - result_unit = next( - u for u in reversed(system.units) if u.power < len(striped_string) - ) - result_string = value_string[: -result_unit.power] - return ( - get_value(result_string) - + [result_unit] - + get_value(striped_string[-result_unit.power :]) - ) - - system = create_system(numbering_type) - - int_dec = number_string.split(".") - if len(int_dec) == 1: - int_string = int_dec[0] - dec_string = "" - elif len(int_dec) == 2: - int_string = int_dec[0] - dec_string = int_dec[1] - else: - raise ValueError( - "invalid input num string with more than one dot: {}".format(number_string) - ) - - if use_units and len(int_string) > 1: - result_symbols = get_value(int_string) - else: - result_symbols = [system.digits[int(c)] for c in int_string] - dec_symbols = [system.digits[int(c)] for c in dec_string] - if dec_string: - result_symbols += [system.math.point] + dec_symbols - - if alt_two: - liang = CND( - 2, - system.digits[2].alt_s, - system.digits[2].alt_t, - system.digits[2].big_s, - system.digits[2].big_t, - ) - for i, v in enumerate(result_symbols): - if isinstance(v, CND) and v.value == 2: - next_symbol = ( - result_symbols[i + 1] if i < len(result_symbols) - 1 else None - ) - previous_symbol = result_symbols[i - 1] if i > 0 else None - if isinstance(next_symbol, CNU) and isinstance( - previous_symbol, (CNU, type(None)) - ): - if next_symbol.power != 1 and ( - (previous_symbol is None) or (previous_symbol.power != 1) - ): - result_symbols[i] = liang - - # if big is True, '两' will not be used and `alt_two` has no impact on output - if big: - attr_name = "big_" - if traditional: - attr_name += "t" - else: - attr_name += "s" - else: - if traditional: - attr_name = "traditional" - else: - attr_name = "simplified" - - result = "".join([getattr(s, attr_name) for s in result_symbols]) - - # if not use_zeros: - # result = result.strip(getattr(system.digits[0], attr_name)) - - if alt_zero: - result = result.replace( - getattr(system.digits[0], attr_name), system.digits[0].alt_s - ) - - if alt_one: - result = result.replace( - getattr(system.digits[1], attr_name), system.digits[1].alt_s - ) - - for i, p in enumerate(POINT): - if result.startswith(p): - return CHINESE_DIGIS[0] + 
result - - # ^10, 11, .., 19 - if ( - len(result) >= 2 - and result[1] - in [ - SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], - SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], - ] - and result[0] - in [ - CHINESE_DIGIS[1], - BIG_CHINESE_DIGIS_SIMPLIFIED[1], - BIG_CHINESE_DIGIS_TRADITIONAL[1], - ] - ): - result = result[1:] - - return result - - -# ================================================================================ # -# different types of rewriters -# ================================================================================ # -class Cardinal: - """ - CARDINAL类 - """ - - def __init__(self, cardinal=None, chntext=None): - self.cardinal = cardinal - self.chntext = chntext - - def chntext2cardinal(self): - return chn2num(self.chntext) - - def cardinal2chntext(self): - return num2chn(self.cardinal) - - -class Digit: - """ - DIGIT类 - """ - - def __init__(self, digit=None, chntext=None): - self.digit = digit - self.chntext = chntext - - # def chntext2digit(self): - # return chn2num(self.chntext) - - def digit2chntext(self): - return num2chn(self.digit, alt_two=False, use_units=False) - - -class TelePhone: - """ - TELEPHONE类 - """ - - def __init__(self, telephone=None, raw_chntext=None, chntext=None): - self.telephone = telephone - self.raw_chntext = raw_chntext - self.chntext = chntext - - # def chntext2telephone(self): - # sil_parts = self.raw_chntext.split('') - # self.telephone = '-'.join([ - # str(chn2num(p)) for p in sil_parts - # ]) - # return self.telephone - - def telephone2chntext(self, fixed=False): - - if fixed: - sil_parts = self.telephone.split("-") - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - else: - sp_parts = self.telephone.strip("+").split() - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - return self.chntext - - -class Fraction: - """ - FRACTION类 - """ - - def __init__(self, fraction=None, chntext=None): - self.fraction = fraction - self.chntext = chntext - - def chntext2fraction(self): - denominator, numerator = self.chntext.split("分之") - return chn2num(numerator) + "/" + chn2num(denominator) - - def fraction2chntext(self): - numerator, denominator = self.fraction.split("/") - return num2chn(denominator) + "分之" + num2chn(numerator) - - -class Date: - """ - DATE类 - """ - - def __init__(self, date=None, chntext=None): - self.date = date - self.chntext = chntext - - # def chntext2date(self): - # chntext = self.chntext - # try: - # year, other = chntext.strip().split('年', maxsplit=1) - # year = Digit(chntext=year).digit2chntext() + '年' - # except ValueError: - # other = chntext - # year = '' - # if other: - # try: - # month, day = other.strip().split('月', maxsplit=1) - # month = Cardinal(chntext=month).chntext2cardinal() + '月' - # except ValueError: - # day = chntext - # month = '' - # if day: - # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] - # else: - # month = '' - # day = '' - # date = year + month + day - # self.date = date - # return self.date - - def date2chntext(self): - date = self.date - try: - year, other = date.strip().split("年", 1) - year = Digit(digit=year).digit2chntext() + "年" - except ValueError: - other = date - year = "" - if other: - try: - month, day = other.strip().split("月", 1) - month = Cardinal(cardinal=month).cardinal2chntext() + "月" - except ValueError: - day = date - month = "" - if day: - day = 
Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] - else: - month = "" - day = "" - chntext = year + month + day - self.chntext = chntext - return self.chntext - - -class Money: - """ - MONEY类 - """ - - def __init__(self, money=None, chntext=None): - self.money = money - self.chntext = chntext - - # def chntext2money(self): - # return self.money - - def money2chntext(self): - money = self.money - pattern = re.compile(r"(\d+(\.\d+)?)") - matchers = pattern.findall(money) - if matchers: - for matcher in matchers: - money = money.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() - ) - self.chntext = money - return self.chntext - - -class Percentage: - """ - PERCENTAGE类 - """ - - def __init__(self, percentage=None, chntext=None): - self.percentage = percentage - self.chntext = chntext - - def chntext2percentage(self): - return chn2num(self.chntext.strip().strip("百分之")) + "%" - - def percentage2chntext(self): - return "百分之" + num2chn(self.percentage.strip().strip("%")) - - -def normalize_nsw(raw_text): - text = "^" + raw_text + "$" - - # 规范化日期 - pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") - matchers = pattern.findall(text) - if matchers: - # print('date') - for matcher in matchers: - text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) - - # 规范化金钱 - pattern = re.compile( - r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)" - ) - matchers = pattern.findall(text) - if matchers: - # print('money') - for matcher in matchers: - text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) - - # 规范化固话/手机号码 - # 手机 - # http://www.jihaoba.com/news/show/13680 - # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 - # 联通:130、131、132、156、155、186、185、176 - # 电信:133、153、189、180、181、177 - pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") - matchers = pattern.findall(text) - if matchers: - # print('telephone') - for matcher in matchers: - text = text.replace( - matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 - ) - # 固话 - pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") - matchers = pattern.findall(text) - if matchers: - # print('fixed telephone') - for matcher in matchers: - text = text.replace( - matcher[0], - TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), - 1, - ) - - # 规范化分数 - pattern = re.compile(r"(\d+/\d+)") - matchers = pattern.findall(text) - if matchers: - # print('fraction') - for matcher in matchers: - text = text.replace( - matcher, Fraction(fraction=matcher).fraction2chntext(), 1 - ) - - # 规范化百分数 - text = text.replace("%", "%") - pattern = re.compile(r"(\d+(\.\d+)?%)") - matchers = pattern.findall(text) - if matchers: - # print('percentage') - for matcher in matchers: - text = text.replace( - matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1 - ) - - # 规范化纯数+量词 - pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) - matchers = pattern.findall(text) - if matchers: - # print('cardinal+quantifier') - for matcher in matchers: - text = text.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 - ) - - # 规范化数字编号 - pattern = re.compile(r"(\d{4,32})") - matchers = pattern.findall(text) - if matchers: - # print('digit') - for matcher in matchers: - text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) - - # 规范化纯数 - pattern = re.compile(r"(\d+(\.\d+)?)") - matchers = pattern.findall(text) - if matchers: - # print('cardinal') - for matcher in matchers: - text = text.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 - ) - - # restore P2P, O2O, B2C, B2B etc - pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") - matchers = pattern.findall(text) - if matchers: - # print('particular') - for matcher in matchers: - text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) - - return text.lstrip("^").rstrip("$") - - -def remove_erhua(text): - """ - 去除儿化音词中的儿: - 他女儿在那边儿 -> 他女儿在那边 - """ - - new_str = "" - while re.search("儿", text): - a = re.search("儿", text).span() - remove_er_flag = 0 - - if ER_WHITELIST_PATTERN.search(text): - b = ER_WHITELIST_PATTERN.search(text).span() - if b[0] <= a[0]: - remove_er_flag = 1 - - if remove_er_flag == 0: - new_str = new_str + text[0 : a[0]] - text = text[a[1] :] - else: - new_str = new_str + text[0 : b[1]] - text = text[b[1] :] - - text = new_str + text - return text - - -def remove_space(text): - tokens = text.split() - new = [] - for k, t in enumerate(tokens): - if k != 0: - if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]): - new.append(" ") - new.append(t) - return "".join(new) - - -class TextNorm: - def __init__( - self, - to_banjiao: bool = True, - to_upper: bool = True, - to_lower: bool = False, - remove_fillers: bool = True, - remove_erhua: bool = True, - check_chars: bool = False, - remove_space: bool = False, - cc_mode: str = "", - ): - self.to_banjiao = to_banjiao - self.to_upper = to_upper - self.to_lower = to_lower - self.remove_fillers = remove_fillers - self.remove_erhua = remove_erhua - self.check_chars = check_chars - self.remove_space = remove_space - - self.cc = None - if cc_mode: - from opencc import OpenCC # Open Chinese Convert: pip install opencc - - self.cc = OpenCC(cc_mode) - - def __call__(self, text): - if self.cc: - text = self.cc.convert(text) - - if self.to_banjiao: - text = text.translate(QJ2BJ_TRANSFORM) - - if self.to_upper: - text = text.upper() - - if self.to_lower: - text = text.lower() - - if self.remove_fillers: - for c in FILLER_CHARS: - text = text.replace(c, "") - - if self.remove_erhua: - text = remove_erhua(text) - - text = normalize_nsw(text) - - text = text.translate(PUNCS_TRANSFORM) - - if self.check_chars: - for c in text: - if not IN_VALID_CHARS.get(c): - print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) - return "" - - if self.remove_space: - text = remove_space(text) - - return text - - -if __name__ == "__main__": - p = argparse.ArgumentParser() - - # normalizer options - p.add_argument( - "--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao" - ) - p.add_argument("--to_upper", action="store_true", help="convert to upper case") - p.add_argument("--to_lower", action="store_true", help="convert to lower case") - p.add_argument( - "--remove_fillers", - action="store_true", - help='remove filler chars such as "呃, 啊"', - ) - p.add_argument( - "--remove_erhua", - action="store_true", - help='remove 
erhua chars such as "他女儿在那边儿 -> 他女儿在那边"', - ) - p.add_argument( - "--check_chars", - action="store_true", - help="skip sentences containing illegal chars", - ) - p.add_argument("--remove_space", action="store_true", help="remove whitespace") - p.add_argument( - "--cc_mode", - choices=["", "t2s", "s2t"], - default="", - help="convert between traditional to simplified", - ) - - # I/O options - p.add_argument( - "--log_interval", - type=int, - default=10000, - help="log interval in number of processed lines", - ) - p.add_argument( - "--has_key", - action="store_true", - help="will be deprecated, set --format ark instead", - ) - p.add_argument( - "--format", - type=str, - choices=["txt", "ark", "tsv"], - default="txt", - help="input format", - ) - p.add_argument("ifile", help="input filename, assume utf-8 encoding") - p.add_argument("ofile", help="output filename") - - args = p.parse_args() - - if args.has_key: - args.format = "ark" - - normalizer = TextNorm( - to_banjiao=args.to_banjiao, - to_upper=args.to_upper, - to_lower=args.to_lower, - remove_fillers=args.remove_fillers, - remove_erhua=args.remove_erhua, - check_chars=args.check_chars, - remove_space=args.remove_space, - cc_mode=args.cc_mode, - ) - - ndone = 0 - with open(args.ifile, "r", encoding="utf8") as istream, open( - args.ofile, "w+", encoding="utf8" - ) as ostream: - if args.format == "tsv": - reader = csv.DictReader(istream, delimiter="\t") - assert "TEXT" in reader.fieldnames - print("\t".join(reader.fieldnames), file=ostream) - - for item in reader: - text = item["TEXT"] - - if text: - text = normalizer(text) - - if text: - item["TEXT"] = text - print("\t".join([item[f] for f in reader.fieldnames]), file=ostream) - - ndone += 1 - if ndone % args.log_interval == 0: - print( - f"text norm: {ndone} lines done.", file=sys.stderr, flush=True - ) - else: - for line in istream: - key, text = "", "" - if args.format == "ark": # KALDI archive, line format: "key text" - cols = line.strip().split(maxsplit=1) - key, text = cols[0], cols[1] if len(cols) == 2 else "" - else: - text = line.strip() - - if text: - text = normalizer(text) - - if text: - if args.format == "ark": - print(key + "\t" + text, file=ostream) - else: - print(text, file=ostream) - - ndone += 1 - if ndone % args.log_interval == 0: - print( - f"text norm: {ndone} lines done.", file=sys.stderr, flush=True - ) - print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True) diff --git a/egs/speechio/ASR/prepare.sh b/egs/speechio/ASR/prepare.sh deleted file mode 100644 index 048a66d8f..000000000 --- a/egs/speechio/ASR/prepare.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -stage=3 -stop_stage=3 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/SPEECHIO_ASR_ZH00000 -# This directory contains the following files downloaded from -# https://github.com/SpeechColab/Leaderboard -# -# - metadata.tsv -# - wav -# - wav.scp -# - trans.txt -# - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare speechio manifest" - # We assume that you have downloaded the speechio dataset - # to $dl_dir - mkdir -p data/manifests - if [ ! -e data/manifests/.speechio.done ]; then - lhotse prepare speechio $dl_dir data/manifests - touch data/manifests/.speechio.done - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute whisper fbank for speechio" - if [ ! -f data/fbank/.speechio.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_speechio.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.speechio.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute kaldi fbank for speechio" - if [ ! -f data/fbank/.speechio.kaldi.done ]; then - fbank_dir=data/fbank_kaldi - mkdir -p $fbank_dir - ./local/compute_fbank_speechio.py --fbank-dir $fbank_dir - touch data/fbank/.speechio.kaldi.done - fi -fi diff --git a/egs/speechio/ASR/shared b/egs/speechio/ASR/shared deleted file mode 120000 index 9d8803a7d..000000000 --- a/egs/speechio/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared// \ No newline at end of file diff --git a/egs/speechio/ASR/whisper/asr_datamodule.py b/egs/speechio/ASR/whisper/asr_datamodule.py deleted file mode 100644 index 7382fd3f5..000000000 --- a/egs/speechio/ASR/whisper/asr_datamodule.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class AsrDataModule: - """ - DataModule for k2 ASR experiments. - There is no train and valid dataloader, for speechio dataset - but there can be multiple test dataloaders. 
- - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=300.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - parser.add_argument( - "--start-index", - type=int, - default=0, - help="Decoding will start from dataset SPEECHIO_ASR_ZH000index", - ) - - parser.add_argument( - "--end-index", - type=int, - default=26, - help="Decoding will end with dataset SPEECHIO_ASR_ZH000index", - ) - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py deleted file mode 100644 index c20f1f714..000000000 --- a/egs/speechio/ASR/whisper/decode.py +++ /dev/null @@ -1,530 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -# Command for decoding using fine-tuned models: -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 - -# Command for decoding using pretrained models (before fine-tuning): - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2_pretrained \ - --model-name large-v2 \ - --epoch -1 --avg 1 \ - --start-index 14 --end-index 15 \ - --remove-whisper-encoder-input-length-restriction False \ - --beam-size 1 --max-duration 50 - -""" - -import argparse -import logging -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -import whisper -from asr_datamodule import AsrDataModule -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from tn.chinese.normalizer import Normalizer -from whisper.normalizers import BasicTextNormalizer -from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from zhconv import convert - -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def average_checkpoints( - filenames: List[Path], device: torch.device = torch.device("cpu") -) -> dict: - """Average a list of checkpoints. - The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. - - Args: - filenames: - Filenames of the checkpoints to be averaged. We assume all - checkpoints are saved by :func:`save_checkpoint`. - device: - Move checkpoints to this device before averaging. - Returns: - Return a dict (i.e., state_dict) which is the average of all - model state dicts contained in the checkpoints. - """ - n = len(filenames) - - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] - else: - avg = torch.load(filenames[0], map_location=device) - - # Identify shared parameters. Two parameters are said to be shared - # if they have the same data_ptr - uniqued: Dict[int, str] = dict() - - for k, v in avg.items(): - v_data_ptr = v.data_ptr() - if v_data_ptr in uniqued: - continue - uniqued[v_data_ptr] = k - - uniqued_names = list(uniqued.values()) - - for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] - else: - state_dict = torch.load(filenames[i], map_location=device) - for k in uniqued_names: - avg[k] += state_dict[k] - - for k in uniqued_names: - if avg[k].is_floating_point(): - avg[k] /= n - else: - avg[k] //= n - - return avg - - -def remove_punctuation(text: str or List[str]): - """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py - - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings without any punctuation. 
- """ - punctuation = "!,.;:?、!,。;:?《》 " - if isinstance(text, str): - text = re.sub(r"[{}]+".format(punctuation), "", text).strip() - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = re.sub(r"[{}]+".format(punctuation), "", t).strip() - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type {type(text)}") - - -def to_simple(text: str or List[str]): - """Convert traditional Chinese to simplified Chinese. - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings converted to simplified Chinese. - """ - if isinstance(text, str): - text = convert(text, "zh-cn") - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = convert(t, "zh-cn") - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type{type(text)}") - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=-1, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="beam-search", - help="""Decoding method. - Supported values are: - - beam-search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=1, - help="beam size for beam search decoding", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--remove-whisper-encoder-input-length-restriction", - type=str2bool, - default=True, - help="replace whisper encoder forward method to remove input length restriction", - ) - - parser.add_argument( - "--use-distill-whisper", - type=str2bool, - default=False, - help="Whether to use architecture of distill whisper.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: "beam-search" - - value: A list of lists. Each sublist is a list of token IDs. - Args: - params: - It is returned by :func:`get_params`. - model: - The neural model. - batch: - It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. - Returns: - Return a dict, whose key may be "beam-search". 
- """ - dtype = torch.float16 - device = torch.device("cuda") - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device, dtype=dtype).transpose(1, 2) - if not params.remove_whisper_encoder_input_length_restriction: - T = 3000 - if feature.shape[2] < T: - feature = torch.cat( - [ - feature, - torch.zeros( - feature.shape[0], feature.shape[1], T - feature.shape[2] - ).to(device, dtype=dtype), - ], - 2, - ) - - supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) - results = model.decode(feature, params.decoding_options) - hyps = [result.text for result in results] - - hyps = remove_punctuation(hyps) - hyps = to_simple(hyps) - hyps = [params.normalizer.normalize(hyp) for hyp in hyps] - print(hyps) - return {"beam-search": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - The dataloader. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a dict, whose key may be "beam-search". - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - batch=batch, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" - ) - - options = whisper.DecodingOptions( - task="transcribe", - language="zh", - without_timestamps=True, - beam_size=params.beam_size, - ) - params.decoding_options = options - params.cleaner = BasicTextNormalizer() - params.normalizer = Normalizer() - - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda") - - logging.info(f"device: {device}") - - if params.remove_whisper_encoder_input_length_restriction: - replace_whisper_encoder_forward() - if params.use_distill_whisper: - replace_whisper_decoder_forward() - model = whisper.load_model(params.model_name, "cpu") - if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - # deepspeed converted checkpoint only contains model state_dict - filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" - for epoch in range(start, params.epoch + 1) - ] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) - else: - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - model.load_state_dict(checkpoint, strict=True) - else: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) - - def remove_long_utt(c: Cut): - # Keep only utterances with duration in 30 seconds - # - if c.duration > 30.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_dls = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/speechio/ASR/whisper/multi_dataset.py b/egs/speechio/ASR/whisper/multi_dataset.py deleted file mode 100644 index f55d45394..000000000 --- a/egs/speechio/ASR/whisper/multi_dataset.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import glob -import logging -import re -from pathlib import Path -from typing import Dict, List - -import lhotse -from lhotse import CutSet, load_manifest_lazy - - -class MultiDataset: - def __init__(self, fbank_dir: str, start_index: int = 0, end_index: int = 26): - """ - Args: - manifest_dir: - It is expected to contain the following files: - - speechio_cuts_SPEECHIO_ASR_ZH00000.jsonl.gz - ... 
- - speechio_cuts_SPEECHIO_ASR_ZH00026.jsonl.gz - """ - self.fbank_dir = Path(fbank_dir) - self.start_index = start_index - self.end_index = end_index - - def test_cuts(self) -> Dict[str, CutSet]: - logging.info("About to get multidataset test cuts") - - dataset_parts = [] - for i in range(self.start_index, self.end_index + 1): - idx = f"{i}".zfill(2) - dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") - - prefix = "speechio" - suffix = "jsonl.gz" - - results_dict = {} - for partition in dataset_parts: - path = f"{prefix}_cuts_{partition}.{suffix}" - - logging.info(f"Loading {path} set in lazy mode") - test_cuts = load_manifest_lazy(self.fbank_dir / path) - results_dict[partition] = test_cuts - - return results_dict diff --git a/egs/speechio/ASR/whisper/requirements.txt b/egs/speechio/ASR/whisper/requirements.txt deleted file mode 120000 index 744bf8bb6..000000000 --- a/egs/speechio/ASR/whisper/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py b/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py deleted file mode 120000 index 167fba1eb..000000000 --- a/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py +++ /dev/null @@ -1 +0,0 @@ -../../../multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py deleted file mode 120000 index 2a7808921..000000000 --- a/egs/speechio/ASR/whisper/whisper_encoder_forward_monkey_patch.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/asr_datamodule.py b/egs/speechio/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index bf446dabe..000000000 --- a/egs/speechio/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../whisper/asr_datamodule.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/beam_search.py b/egs/speechio/ASR/zipformer/beam_search.py deleted file mode 120000 index 8e2c0a65c..000000000 --- a/egs/speechio/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/ctc_decode.py b/egs/speechio/ASR/zipformer/ctc_decode.py deleted file mode 100644 index f9d0db993..000000000 --- a/egs/speechio/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1,623 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Liyong Guo, -# Quandong Wang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
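The MultiDataset helper above enumerates the 27 SPEECHIO test partitions by zero-padded index; the sketch below spells out the manifest naming scheme it expects (the fbank directory is illustrative):

# Partitions run from SPEECHIO_ASR_ZH00000 to SPEECHIO_ASR_ZH00026, and each cut set is
# expected at <fbank_dir>/speechio_cuts_<partition>.jsonl.gz.
from pathlib import Path

fbank_dir = Path("data/fbank")  # illustrative; the recipe passes this via --manifest-dir
for i in range(0, 26 + 1):
    partition = f"SPEECHIO_ASR_ZH000{str(i).zfill(2)}"
    print(fbank_dir / f"speechio_cuts_{partition}.jsonl.gz")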
-""" -Usage: - -(1) ctc-decoding -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method ctc-decoding - -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import get_lattice, one_best_decoding -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="ctc-decoding", - help="""Decoding method. - Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. 
- Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. - - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. - - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. - - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.decoding_method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - word_table: - It is the word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - H=H, - bpe_model=bpe_model, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text.replace(" ", "")) - hyp_words = list("".join(hyp_words)) - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - assert params.decoding_method in ("ctc-decoding",) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - params.vocab_size = num_classes - # and are defined in local/train_bpe_model.py - params.blank_id = 0 - - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=True, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - - G = None - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) - - test_sets_cuts = multi_dataset.test_cuts() - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/speechio/ASR/zipformer/decode.py b/egs/speechio/ASR/zipformer/decode.py deleted file mode 100644 index ffdd7b500..000000000 --- a/egs/speechio/ASR/zipformer/decode.py +++ /dev/null @@ -1,843 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
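The remove_short_utt filter used in ctc_decode.py above drops cuts that would be empty after the zipformer front end; a standalone sketch of that length check (hypothetical helper, same arithmetic as the deleted code):

# The check assumes the front end reduces T input frames to roughly T/4:
# T' = ((T - 7) // 2 + 1) // 2; cuts with T' <= 0 are excluded before decoding.
def keep_for_decoding(num_frames: int) -> bool:
    t_prime = ((num_frames - 7) // 2 + 1) // 2
    return t_prime > 0

assert keep_for_decoding(100) is True
assert keep_for_decoding(5) is False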
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from multi_dataset import MultiDataset -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. 
Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_2000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). 
- """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - 
num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - hyp_text = "".join(hyp_words) - this_batch.append((cut_id, ref_text, hyp_text)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - params.suffix += f"-blank-penalty-{params.blank_penalty}" - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the 
averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir, args.start_index, args.end_index) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" - ) - return T > 0 - - test_sets_cuts = multi_dataset.test_cuts() - - test_sets = test_sets_cuts.keys() - test_dl = [ - data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) - for cuts_name in test_sets - ] - - for test_set, test_dl in zip(test_sets, test_dl): - logging.info(f"Start decoding test set: {test_set}") - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/speechio/ASR/zipformer/decoder.py b/egs/speechio/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/speechio/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/encoder_interface.py b/egs/speechio/ASR/zipformer/encoder_interface.py deleted file mode 120000 index c2eaca671..000000000 --- a/egs/speechio/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/joiner.py b/egs/speechio/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/speechio/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/model.py b/egs/speechio/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/speechio/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/multi_dataset.py b/egs/speechio/ASR/zipformer/multi_dataset.py deleted file mode 120000 index af164667a..000000000 --- a/egs/speechio/ASR/zipformer/multi_dataset.py +++ /dev/null @@ -1 +0,0 @@ -../whisper/multi_dataset.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/optim.py 
b/egs/speechio/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/speechio/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/scaling.py b/egs/speechio/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/speechio/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/scaling_converter.py b/egs/speechio/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/speechio/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/subsampling.py b/egs/speechio/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/speechio/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/train.py b/egs/speechio/ASR/zipformer/train.py deleted file mode 120000 index ad7216cf7..000000000 --- a/egs/speechio/ASR/zipformer/train.py +++ /dev/null @@ -1 +0,0 @@ -../../../multi_zh-hans/ASR/zipformer/train.py \ No newline at end of file diff --git a/egs/speechio/ASR/zipformer/zipformer.py b/egs/speechio/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/speechio/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/README.md b/egs/spgispeech/ASR/README.md deleted file mode 100644 index f60408cc1..000000000 --- a/egs/spgispeech/ASR/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# SPGISpeech - -SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective -transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in -length to allow easy training for speech recognition systems. Calls represent a broad -cross-section of international business English; SPGISpeech contains approximately 50,000 -speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and -L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio. - -Transcription text represents the output of several stages of manual post-processing. -As such, the text contains polished English orthography following a detailed style guide, -including proper casing, punctuation, and denormalized non-standard words such as numbers -and acronyms, making SPGISpeech suited for training fully formatted end-to-end models. - -Official reference: - -O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam, -J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G. -(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted -end-to-end speech recognition. ArXiv, abs/2104.02014. - -ArXiv link: https://arxiv.org/abs/2104.02014 - -## Performance Record - -| Decoding method | val WER | val CER | -|---------------------------|------------|---------| -| greedy search | 2.40 | 0.99 | -| modified beam search | 2.24 | 0.91 | -| fast beam search | 2.35 | 0.97 | - -See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details. 
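The README above notes that every SPGISpeech WAV slice is single-channel, 16 kHz, 16-bit audio. A quick sanity check along those lines could look like the following sketch; the file path is hypothetical and only the format assertions come from the README:

```python
import wave

# Check one downloaded SPGISpeech slice against the stated format:
# mono, 16 kHz sampling rate, 16-bit (2-byte) samples. Path is hypothetical.
with wave.open("download/spgispeech/train/sample.wav", "rb") as w:
    assert w.getnchannels() == 1          # single channel
    assert w.getframerate() == 16000      # 16 kHz
    assert w.getsampwidth() == 2          # 16-bit samples
    print(f"duration: {w.getnframes() / w.getframerate():.2f} s")
```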
diff --git a/egs/spgispeech/ASR/RESULTS.md b/egs/spgispeech/ASR/RESULTS.md deleted file mode 100644 index f2da53193..000000000 --- a/egs/spgispeech/ASR/RESULTS.md +++ /dev/null @@ -1,138 +0,0 @@ -## Results - -### SPGISpeech BPE training results (Zipformer Transducer) - -#### 2024-01-05 - -#### Zipformer encoder + embedding decoder - -Transducer: Zipformer encoder + stateless decoder. - -The WERs are: - -| | dev | val | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.08 | 2.14 | --epoch 30 --avg 10 | -| modified beam search | 2.05 | 2.09 | --epoch 30 --avg 10 --beam-size 4 | -| fast beam search | 2.07 | 2.17 | --epoch 30 --avg 10 --beam 20 --max-contexts 8 --max-states 64 | - -**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the -transcripts are orthographic or normalized. These WERs correspond to the normalized transcription -scenario. - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -python zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --num-workers 2 \ - --max-duration 1000 -``` - -The decoding command is: -``` -# greedy search -python ./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method greedy_search - -# modified beam search -python ./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method modified_beam_search - -# fast beam search -python ./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./zipformer/exp \ - --max-duration 1000 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -### SPGISpeech BPE training results (Pruned Transducer) - -#### 2022-05-11 - -#### Conformer encoder + embedding decoder - -Conformer encoder + non-current decoder. The decoder -contains only an embedding layer, a Conv1d (with kernel size 2) and a linear -layer (to transform tensor dim). - -The WERs are - -| | dev | val | comment | -|---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 | -| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | -| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | - -**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the -transcripts are orthographic or normalized. These WERs correspond to the normalized transcription -scenario. 
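In the WER tables above, comments such as `--epoch 30 --avg 10` identify which checkpoints are averaged for decoding: the last 10 epoch checkpoints ending at epoch 30 (the `--iter 696000 --avg 10` variant instead reads `checkpoint-<iter>.pt` files selected by `find_checkpoints`). A small sketch of the epoch-based mapping, mirroring the `start = epoch - avg + 1` logic used by the decode scripts later in this diff:

```python
# Which epoch checkpoints "--epoch 30 --avg 10" refers to: epoch-21.pt .. epoch-30.pt.
def averaged_checkpoint_paths(epoch: int, avg: int, exp_dir: str = "zipformer/exp"):
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]

print(averaged_checkpoint_paths(30, 10))
# ['zipformer/exp/epoch-21.pt', ..., 'zipformer/exp/epoch-30.pt']
```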
- -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless2/train.py \ - --world-size 8 \ - --num-epochs 20 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --max-duration 200 \ - --prune-range 5 \ - --lr-factor 5 \ - --lm-scale 0.25 \ - --use-fp16 True -``` - -The decoding command is: -``` -# greedy search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -# modified beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -# fast beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -Pretrained model is available at - -The tensorboard training log can be found at - diff --git a/egs/spgispeech/ASR/local/__init__.py b/egs/spgispeech/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/spgispeech/ASR/local/compile_hlg.py b/egs/spgispeech/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/spgispeech/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py deleted file mode 100755 index 9bea28a41..000000000 --- a/egs/spgispeech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -from pathlib import Path - -import torch -from lhotse import CutSet, LilcomChunkyWriter, combine -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - dataset_parts = ( - "music", - "speech", - "noise", - ) - manifests = read_manifests_if_cached( - prefix="musan", dataset_parts=dataset_parts, output_dir=src_dir - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = src_dir / "cuts_musan.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "feats_musan", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - ) - - logging.info(f"Saving to {musan_cuts_path}") - musan_cuts.to_file(musan_cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py deleted file mode 100755 index 20ff6d7ab..000000000 --- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the SPGISpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import argparse -import logging -from pathlib import Path - -import torch -from lhotse import LilcomChunkyWriter, load_manifest_lazy -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
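After the MUSAN extraction pipeline above (`compute_fbank_musan.py`) finishes, `data/manifests/cuts_musan.jsonl.gz` should contain only windows between 5 and 10 seconds long. A quick inspection sketch, not part of the recipe:

```python
from lhotse import load_manifest_lazy

# Load the manifest written by compute_fbank_musan() and check the window lengths
# produced by cut_into_windows(10.0).filter(duration > 5).
cuts = load_manifest_lazy("data/manifests/cuts_musan.jsonl.gz")
durations = [c.duration for c in cuts]
print(f"{len(durations)} cuts, min {min(durations):.1f} s, max {max(durations):.1f} s")
```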
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-splits", - type=int, - default=20, - help="Number of splits for the train set.", - ) - parser.add_argument( - "--start", - type=int, - default=0, - help="Start index of the train set split.", - ) - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop index of the train set split.", - ) - parser.add_argument( - "--test", - action="store_true", - help="If set, only compute features for the dev and val set.", - ) - parser.add_argument( - "--train", - action="store_true", - help="If set, only compute features for the train set.", - ) - - return parser.parse_args() - - -def compute_fbank_spgispeech(args): - assert args.train or args.test, "Either train or test must be set." - - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - if args.train: - logging.info("Processing train") - cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz") - chunk_size = len(cut_set) // args.num_splits - cut_sets = cut_set.split_lazy( - output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}", - chunk_size=chunk_size, - ) - start = args.start - stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits - num_digits = len(str(args.num_splits)) - for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) - cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz" - logging.info(f"Processing train split {i}") - cs = cut_sets[i].compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_train_{idx}", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - cs.to_file(cuts_train_idx_path) - - if args.test: - for partition in ["dev", "val"]: - if (output_dir / f"cuts_{partition}.jsonl.gz").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / f"feats_{partition}", - manifest_path=src_dir / f"cuts_{partition}.jsonl.gz", - batch_duration=500, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_spgispeech(args) diff --git a/egs/spgispeech/ASR/local/prepare_lang.py b/egs/spgispeech/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/spgispeech/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/local/prepare_lang_bpe.py b/egs/spgispeech/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/spgispeech/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/local/prepare_splits.py 
b/egs/spgispeech/ASR/local/prepare_splits.py deleted file mode 100755 index 508d4acd8..000000000 --- a/egs/spgispeech/ASR/local/prepare_splits.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file splits the training set into train and dev sets. -""" -import logging -from pathlib import Path - -import torch -from lhotse import CutSet -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def split_spgispeech_train(): - src_dir = Path("data/manifests") - - manifests = read_manifests_if_cached( - dataset_parts=["train", "val"], - output_dir=src_dir, - prefix="spgispeech", - suffix="jsonl.gz", - lazy=True, - ) - assert manifests is not None - - train_dev_cuts = CutSet.from_manifests( - recordings=manifests["train"]["recordings"], - supervisions=manifests["train"]["supervisions"], - ) - dev_cuts = train_dev_cuts.subset(first=4000) - train_cuts = train_dev_cuts.filter(lambda c: c not in dev_cuts) - - # Add speed perturbation - train_cuts = ( - train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1) - ) - - # Write the manifests to disk. - train_cuts.to_file(src_dir / "cuts_train_raw.jsonl.gz") - dev_cuts.to_file(src_dir / "cuts_dev_raw.jsonl.gz") - - # Also write the val set to disk. - val_cuts = CutSet.from_manifests( - recordings=manifests["val"]["recordings"], - supervisions=manifests["val"]["supervisions"], - ) - val_cuts.to_file(src_dir / "cuts_val_raw.jsonl.gz") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - split_spgispeech_train() diff --git a/egs/spgispeech/ASR/local/train_bpe_model.py b/egs/spgispeech/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/spgispeech/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh deleted file mode 100755 index 8331f94d5..000000000 --- a/egs/spgispeech/ASR/prepare.sh +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=20 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. 
-# -# - $dl_dir/spgispeech -# You can find train.csv, val.csv, train, and val in this directory, which belong -# to the SPGISpeech dataset. -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/spgispeech, - # you can create a symlink - # - # ln -sfv /path/to/spgispeech $dl_dir/spgispeech - # - if [ ! -d $dl_dir/spgispeech/train.csv ]; then - lhotse download spgispeech $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare SPGISpeech manifest (may take ~1h)" - # We assume that you have downloaded the SPGISpeech corpus - # to $dl_dir/spgispeech. We perform text normalization for the transcripts. - mkdir -p data/manifests - lhotse prepare spgispeech -j $nj --normalize-text $dl_dir/spgispeech data/manifests -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - lhotse combine data/manifests/recordings_{music,speech,noise}.json data/manifests/recordings_musan.jsonl.gz - lhotse cut simple -r data/manifests/recordings_musan.jsonl.gz data/manifests/cuts_musan_raw.jsonl.gz -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Split train into train and dev and create cut sets." - python local/prepare_splits.py -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank features for spgispeech dev and val" - mkdir -p data/fbank - python local/compute_fbank_spgispeech.py --test -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank features for train" - mkdir -p data/fbank - python local/compute_fbank_spgispeech.py --train --num-splits 20 - - log "Combine features from train splits (may take ~1h)" - if [ ! 
-f data/manifests/cuts_train.jsonl.gz ]; then - pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz") - lhotse combine $pieces data/manifests/cuts_train.jsonl.gz - fi - gunzip -c data/manifests/cuts_train.jsonl.gz | shuf | gzip -c > data/manifests/cuts_train_shuf.jsonl.gz -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compute fbank features for musan" - mkdir -p data/fbank - python local/compute_fbank_musan.py -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Dump transcripts for LM training" - mkdir -p data/lm - gunzip -c data/manifests/cuts_train_raw.jsonl.gz \ - | jq '.supervisions[0].text' \ - | sed 's:"::g' \ - > data/lm/transcript_words.txt -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - # Add special words to words.txt - echo " 0" > $lang_dir/words.txt - echo "!SIL 1" >> $lang_dir/words.txt - echo " 2" >> $lang_dir/words.txt - - # Add regular words to words.txt - gunzip -c data/manifests/cuts_train_raw.jsonl.gz \ - | jq '.supervisions[0].text' \ - | sed 's:"::g' \ - | sed 's: :\n:g' \ - | sort \ - | uniq \ - | sed '/^$/d' \ - | awk '{print $0,NR+2}' \ - >> $lang_dir/words.txt - - # Add remaining special word symbols expected by LM scripts. - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(cat $lang_dir/words.txt | wc -l) - echo "#0 ${num_words}" >> $lang_dir/words.txt - - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript data/lm/transcript_words.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Train LM" - lm_dir=data/lm - - if [ ! -f $lm_dir/G.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lm_dir/transcript_words.txt \ - -lm $lm_dir/G.arpa - fi - - if [ ! -f $lm_dir/G_3_gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $lm_dir/G.arpa > $lm_dir/G_3_gram.fst.txt - fi -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - done -fi diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index 75c5385a7..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
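Stage 8 of `prepare.sh` above builds `words.txt` by reserving a few special symbols before the regular words and appending LM-related symbols afterwards; note that the angle-bracketed tokens have been stripped from the `echo` lines in this diff. A hedged sketch of the intended layout, assuming the usual k2/icefall symbols (`<eps>`, `<UNK>`, `<s>`, `</s>`, `#0`):

```python
# A sketch of the words.txt layout produced by Stage 8. The special symbols are
# assumptions (the angle-bracketed tokens were lost in extraction); regular words
# are sorted, deduplicated and numbered from 3 upwards, and LM/disambiguation
# symbols are appended at the end, as in the shell snippet above.
def build_words_txt(words, out_path="data/lang_bpe_500/words.txt"):
    lines = ["<eps> 0", "!SIL 1", "<UNK> 2"]   # assumed special symbols for ids 0-2
    next_id = 3
    for w in sorted(set(words)):               # regular words from the transcripts
        lines.append(f"{w} {next_id}")
        next_id += 1
    for sym in ("<s>", "</s>", "#0"):          # assumed LM / disambiguation symbols
        lines.append(f"{sym} {next_id}")
        next_id += 1
    with open(out_path, "w") as f:
        f.write("\n".join(lines) + "\n")

build_words_txt(["revenue", "quarter", "guidance"], out_path="/tmp/words.txt")
```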
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader -from tqdm import tqdm - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class SPGISpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=False, - help="When enabled, the last batch will be dropped", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--max-duration", - type=int, - default=100.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - if self.args.on_the_fly_feats: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - else: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=False, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get SPGISpeech train cuts") - return 
load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get SPGISpeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") - - @lru_cache() - def val_cuts(self) -> CutSet: - logging.info("About to get SPGISpeech val cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz") - - -def test(): - parser = argparse.ArgumentParser() - SPGISpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - adm = SPGISpeechAsrDataModule(args) - - cuts = adm.train_cuts() - dl = adm.train_dataloaders(cuts) - for i, batch in tqdm(enumerate(dl)): - if i == 100: - break - - cuts = adm.dev_cuts() - dl = adm.valid_dataloaders(cuts) - for i, batch in tqdm(enumerate(dl)): - if i == 100: - break - - -if __name__ == "__main__": - test() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index 4434aae62..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,566 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 \ - --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import SPGISpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - test_set_cers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" - with open(wers_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - # we also compute CER for spgispeech dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" - with open(cers_filename, "w") as f: - cer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True - ) - test_set_cers[key] = cer - - logging.info("Wrote detailed error stats to {}".format(wers_filename)) - - test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} - test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER\tCER", file=f) - for key in test_set_wers: - print( - "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), - file=f, - ) - - s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key in test_set_wers: - s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - SPGISpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += 
f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - spgispeech = SPGISpeechAsrDataModule(args) - - dev_cuts = spgispeech.dev_cuts() - val_cuts = spgispeech.val_cuts() - - dev_dl = spgispeech.test_dataloaders(dev_cuts) - val_dl = spgispeech.test_dataloaders(val_cuts) - - test_sets = ["dev", "val"] - test_dl = [dev_dl, val_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index 68763808a..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
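The export script's header above says it "converts several saved checkpoints to a single one using model averaging". Conceptually that is a parameter-wise mean over the selected checkpoints; a rough sketch of the idea follows (icefall's own `average_checkpoints()` helper may differ in details, e.g. how non-float entries are handled):

```python
import torch

# A rough sketch of checkpoint averaging: load each checkpoint's "model" state
# dict and take the element-wise mean of every floating-point tensor.
def average_state_dicts(filenames, device="cpu"):
    avg = None
    for f in filenames:
        state = torch.load(f, map_location=device)["model"]
        if avg is None:
            avg = {k: v.clone().float() if torch.is_floating_point(v) else v
                   for k, v in state.items()}
        else:
            for k, v in state.items():
                if torch.is_floating_point(v):
                    avg[k] += v.float()
    n = len(filenames)
    for k, v in avg.items():
        if torch.is_floating_point(v):
            avg[k] = v / n
    return avg
```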
-""" -Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --avg-last-n 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/spgispeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/model.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 ---
a/egs/spgispeech/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py deleted file mode 100755 index a9146a0fe..000000000 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,1003 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --use_fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import SPGISpeechAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs 
to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=4, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - 
decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. 
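The warm-up comment inside `compute_loss` above defines a piecewise weight for the pruned loss as a function of `warmup = batch_idx_train / model_warm_step`. Written out on its own for readability (this mirrors the inline expression; it is not a separate function in the recipe):

def pruned_loss_scale(warmup: float) -> float:
    # The pruned loss is ignored until warm-up completes, then down-weighted
    # for a while before receiving full weight.
    if warmup < 1.0:
        return 0.0
    if 1.0 < warmup < 2.0:
        return 0.1
    return 1.0


# With the default model_warm_step = 3000:
#   batch 1500 -> 0.0, batch 4500 -> 0.1, batch 9000 -> 1.0
# and the total loss is
#   simple_loss_scale * simple_loss + pruned_loss_scale(warmup) * pruned_loss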
If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 30: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - spgispeech = SPGISpeechAsrDataModule(args) - - train_cuts = spgispeech.train_cuts() - - # Ideally we should filter utterances that are too long or too short, - # but SPGISpeech contains regular length utterances so we don't need to - # do that.
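The comment above notes that most corpora do need a duration filter even though SPGISpeech does not. For reference, a hedged sketch of how such a filter is usually written with lhotse's `CutSet.filter` (the bounds are illustrative, not values taken from this recipe):

def remove_short_and_long_utt(c) -> bool:
    # Keep cuts between 1 s and 20 s; tune the bounds per corpus.
    return 1.0 <= c.duration <= 20.0


train_cuts = train_cuts.filter(remove_short_and_long_utt)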
Here are the statistics of the training data (obtained by - # `train_cuts.describe()`): - - # Cuts count: 5886320 - # Total duration (hours): 15070.1 - # Speech duration (hours): 15070.1 (100.0%) - # *** - # Duration statistics (seconds): - # mean 9.2 - # std 2.8 - # min 4.6 - # 25% 6.9 - # 50% 8.9 - # 75% 11.2 - # 99% 16.0 - # 99.5% 16.3 - # 99.9% 16.6 - # max 16.7 - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = spgispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = spgispeech.dev_cuts() - valid_dl = spgispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. 
are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - - -def main(): - parser = get_parser() - SPGISpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/spgispeech/ASR/shared b/egs/spgispeech/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/spgispeech/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/asr_datamodule.py b/egs/spgispeech/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/spgispeech/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/beam_search.py b/egs/spgispeech/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/spgispeech/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/decode.py b/egs/spgispeech/ASR/zipformer/decode.py deleted file mode 100755 index 90d318919..000000000 --- a/egs/spgispeech/ASR/zipformer/decode.py +++ /dev/null @@ -1,1052 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
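When the OOM sanity check above fails, `display_and_save_batch` dumps the offending batch to `exp-dir/batch-{uuid}.pt`. A small sketch of inspecting such a dump offline (the file name is a placeholder to be taken from the training log):

import torch

batch = torch.load(
    "pruned_transducer_stateless2/exp/batch-0ab1.pt", map_location="cpu"
)

features = batch["inputs"]                       # (N, T, C) fbank features
texts = batch["supervisions"]["text"]            # transcripts of the cuts
num_frames = batch["supervisions"]["num_frames"]

print(features.shape, int(num_frames.max()), len(texts))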
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import SPGISpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - modified_beam_search_LODR - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding-method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding-method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding-method is modified_beam_search and - modified_beam_search_LODR. - """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - context_graph=context_graph, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"{prefix}_{key}"] = hyps - return ans - else: - if params.has_contexts: - prefix += f"-context-score-{params.context_score}" - return {prefix: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - SPGISpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." 
- assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} 
(excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - if "modified_beam_search" in params.decoding_method: - if os.path.exists(params.context_file): - contexts = [] - for line in open(params.context_file).readlines(): - contexts.append(line.strip()) - context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) - else: - context_graph = None - else: - context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
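The two epoch-averaging paths above differ in which checkpoint files they actually read. A minimal sketch of that file-selection logic (illustrative only, not part of the deleted script), for the common `--epoch 30 --avg 9` setting:

```python
# Illustrative sketch only: which epoch checkpoints each averaging mode
# above reads, e.g. for --epoch 30 --avg 9.
from typing import List, Tuple


def plain_average_files(exp_dir: str, epoch: int, avg: int) -> List[str]:
    # --use-averaged-model 0: average_checkpoints() averages `avg` files,
    # epochs epoch-avg+1 .. epoch (inclusive).
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]


def averaged_model_endpoints(exp_dir: str, epoch: int, avg: int) -> Tuple[str, str]:
    # --use-averaged-model 1: average_checkpoints_with_averaged_model() only needs
    # the running `model_avg` stored in two checkpoints; the start epoch is excluded.
    start = epoch - avg
    assert start >= 1, start
    return f"{exp_dir}/epoch-{start}.pt", f"{exp_dir}/epoch-{epoch}.pt"


print(plain_average_files("zipformer/exp", 30, 9))       # epoch-22.pt ... epoch-30.pt
print(averaged_model_endpoints("zipformer/exp", 30, 9))  # ('.../epoch-21.pt', '.../epoch-30.pt')
```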
- spgispeech = SPGISpeechAsrDataModule(args) - - dev_cuts = spgispeech.dev_cuts() - val_cuts = spgispeech.val_cuts() - - dev_dl = spgispeech.test_dataloaders(dev_cuts) - val_dl = spgispeech.test_dataloaders(val_cuts) - - test_sets = ["dev", "val"] - test_dl = [dev_dl, val_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/spgispeech/ASR/zipformer/decoder.py b/egs/spgispeech/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/spgispeech/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/encoder_interface.py b/egs/spgispeech/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/spgispeech/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/joiner.py b/egs/spgispeech/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/spgispeech/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/model.py b/egs/spgispeech/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/spgispeech/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/optim.py b/egs/spgispeech/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/spgispeech/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/pretrained.py b/egs/spgispeech/ASR/zipformer/pretrained.py deleted file mode 100755 index a562fb9f6..000000000 --- a/egs/spgispeech/ASR/zipformer/pretrained.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is a example for spgispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. 
- -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp/epoch-xx.pt`. - -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.utils import make_pad_mask - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/spgispeech/ASR/zipformer/scaling.py b/egs/spgispeech/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/spgispeech/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/scaling_converter.py b/egs/spgispeech/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/spgispeech/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/subsampling.py b/egs/spgispeech/ASR/zipformer/subsampling.py 
deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/spgispeech/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py deleted file mode 100755 index dfc21c968..000000000 --- a/egs/spgispeech/ASR/zipformer/train.py +++ /dev/null @@ -1,1364 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import SPGISpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: 
- # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. 
" - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. 
It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
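The comment above pins down the shape contract of `encoder_embed`. A quick worked instance of that arithmetic (a sketch under the stated contract, not recipe code), ahead of the constructor call that follows:

```python
# Sketch of the shape arithmetic described above (not part of the recipe).
def embed_frames(num_fbank_frames: int) -> int:
    # Conv2dSubsampling: (N, T, feature_dim) -> (N, (T - 7) // 2, encoder_dim[0])
    return (num_fbank_frames - 7) // 2


# 10 s of audio at 100 fbank frames/sec -> 496 frames entering the encoder,
# which the encoder then roughly halves again (output_downsampling_factor=2).
assert embed_frames(1000) == 496
```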
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. 
- """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. 
- valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - spgispeech = SPGISpeechAsrDataModule(args) - - train_cuts = spgispeech.train_cuts() - - # Ideally we should filter utterances that are too long or too short, - # but SPGISpeech contains regular length utterances so we don't need to - # do that. 
Here are the statistics of the training data (obtained by - # `train_cuts.describe()`): - - # Cuts count: 5886320 - # Total duration (hours): 15070.1 - # Speech duration (hours): 15070.1 (100.0%) - # *** - # Duration statistics (seconds): - # mean 9.2 - # std 2.8 - # min 4.6 - # 25% 6.9 - # 50% 8.9 - # 75% 11.2 - # 99% 16.0 - # 99.5% 16.3 - # 99.9% 16.6 - # max 16.7 - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = spgispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = spgispeech.dev_cuts() - valid_dl = spgispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - SPGISpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/spgispeech/ASR/zipformer/zipformer.py b/egs/spgispeech/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/spgispeech/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore deleted file mode 100644 index 11d674922..000000000 --- a/egs/swbd/ASR/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -switchboard_word_alignments.tar.gz -./swb_ms98_transcriptions/ diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md deleted file mode 100644 index 13b27815a..000000000 --- a/egs/swbd/ASR/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Switchboard - -The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed. - -Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic. - -(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) - - -## Performance Record -| | eval2000 | rt03 | -|--------------------------------|------------|--------| -| `conformer_ctc` | 33.37 | 35.06 | - -See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. 
- -## Credit - -The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. - -A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. - -Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). - -The `sclite_scoring.py` is from the GigaSpeech recipe for post processing and glm-like scoring, which is definitely not an elegant stuff to do. diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md deleted file mode 100644 index f3a22c444..000000000 --- a/egs/swbd/ASR/RESULTS.md +++ /dev/null @@ -1,113 +0,0 @@ -## Results -### Switchboard BPE training results (Conformer-CTC) - -#### 2023-09-04 - -The best WER, as of 2023-09-04, for the Switchboard is below - -Results using attention decoder are given as: - -| | eval2000-swbd | eval2000-callhome | eval2000-avg | -|--------------------------------|-----------------|---------------------|--------------| -| `conformer_ctc` | 9.48 | 17.73 | 13.67 | - -Decoding results and models can be found here: -https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 -#### 2023-06-27 - -The best WER, as of 2023-06-27, for the Switchboard is below - -Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: - -| | eval2000 | rt03 | -|--------------------------------|------------|--------| -| `conformer_ctc` | 30.80 | 32.29 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: - -##### eval2000 - -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.9 | 1.1 | - -##### rt03 - -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.9 | 1.9 | - -To reproduce the above result, use the following commands for training: - -```bash -cd egs/swbd/ASR -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1" -./conformer_ctc/train.py \ - --max-duration 120 \ - --num-workers 8 \ - --enable-musan False \ - --world-size 2 \ - --num-epochs 100 -``` - -and the following command for decoding: - -```bash -./conformer_ctc/decode.py \ - --epoch 99 \ - --avg 10 \ - --max-duration 50 -``` - -#### 2023-06-26 - -The best WER, as of 2023-06-26, for the Switchboard is below - -Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: - -| | eval2000 | rt03 | -|--------------------------------|------------|--------| -| `conformer_ctc` | 33.37 | 35.06 | - -Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: - -##### eval2000 - -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.3 | 2.5 | - -##### rt03 - -| ngram_lm_scale | attention_scale | -|----------------|-----------------| -| 0.7 | 1.3 | - -To reproduce the above result, use the following commands for training: - -```bash -cd egs/swbd/ASR -./prepare.sh -export CUDA_VISIBLE_DEVICES="0,1" -./conformer_ctc/train.py \ - --max-duration 120 \ - --num-workers 8 \ - --enable-musan False \ - --world-size 2 \ -``` - -and the following command for decoding: - -```bash -./conformer_ctc/decode.py \ - --epoch 55 \ - --avg 1 \ - --max-duration 50 -``` - -For your reference, the nbest oracle WERs are: - -| | eval2000 | rt03 | -|--------------------------------|------------|--------| -| `conformer_ctc` | 25.64 | 26.84 | diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py 
b/egs/swbd/ASR/conformer_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py deleted file mode 100644 index 0f6f02e8d..000000000 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1,417 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# Modified by Zengrui Jin for the SwitchBoard corpus -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class SwitchBoardAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train dataloader, - but there can be multiple test dataloaders (e.g. SwitchBoard rt03 - and eval2000). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. 
mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=50, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
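Editorial note: the comment above explains the worker-seeding scheme used by `train_dataloaders`: a base seed is drawn from the main-process RNG (which `fix_random_seed()` has already seeded), and each dataloader worker offsets it by its worker id, so workers produce different augmentations while the whole run stays reproducible. Below is a hedged, toy-sized sketch of that pattern; the dataset is a placeholder, and plain `random`/`torch` seeding stands in for `lhotse.utils.fix_random_seed`.

```python
# Toy sketch of the per-worker seeding pattern (not the recipe's
# K2SpeechRecognitionDataset): every worker gets base_seed + worker_id.
import random

import torch
from torch.utils.data import DataLoader, Dataset


class _SeedWorkers:
    """Picklable callable so it survives the worker spawn/fork."""

    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int) -> None:
        # The real recipe calls lhotse.utils.fix_random_seed() here, which
        # seeds python, numpy, and torch in one go.
        seed = self.seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Each worker draws from its own seeded RNG stream.
        return torch.rand(1)


if __name__ == "__main__":
    torch.manual_seed(42)  # normally done once by fix_random_seed(params.seed)
    base_seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_SeedWorkers(base_seed),
    )
    for batch in dl:
        print(batch)
```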
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_all_cuts(self) -> CutSet: - logging.info("SwitchBoard: About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" - ).subset(last=166844) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("SwitchBoard: About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" - ).subset(first=300) - - @lru_cache() - def test_eval2000_cuts(self) -> CutSet: - logging.info("SwitchBoard: About to get eval2000 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" - ) - - @lru_cache() - def test_rt03_cuts(self) -> CutSet: - logging.info("SwitchBoard: About to get rt03 cuts") - return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py deleted file mode 120000 index d1f4209d7..000000000 --- a/egs/swbd/ASR/conformer_ctc/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py deleted file mode 100755 index 44b2e95d6..000000000 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from conformer import Conformer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=98, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=55, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - "num_decoder_layers": 6, - } - ) - return params - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=params.vocab_size, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by 
:func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/swbd/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py deleted file mode 120000 index 526bc9678..000000000 --- a/egs/swbd/ASR/conformer_ctc/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py deleted file mode 100755 index 0383c4d71..000000000 --- a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Jiayu Du -# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import os - -conversational_filler = [ - "UH", - "UHH", - "UM", - "EH", - "MM", - "HM", - "AH", - "HUH", - "HA", - "ER", - "OOF", - "HEE", - "ACH", - "EEE", - "EW", - "MHM", - "HUM", - "AW", - "OH", - "HMM", - "UMM", -] -unk_tags = ["", ""] -switchboard_garbage_utterance_tags = [ - "[LAUGHTER]", - "[NOISE]", - "[VOCALIZED-NOISE]", - "[SILENCE]", -] -non_scoring_words = ( - conversational_filler + unk_tags + switchboard_garbage_utterance_tags -) - - -def asr_text_post_processing(text: str) -> str: - # 1. convert to uppercase - text = text.upper() - - # 2. 
remove non-scoring words from evaluation - remaining_words = [] - text_split = text.split() - word_to_skip = 0 - for idx, word in enumerate(text_split): - if word_to_skip > 0: - word_to_skip -= 1 - continue - if word in non_scoring_words: - continue - elif word == "CANCELLED": - remaining_words.append("CANCELED") - continue - elif word == "AIRFLOW": - remaining_words.append("AIR") - remaining_words.append("FLOW") - continue - elif word == "PHD": - remaining_words.append("P") - remaining_words.append("H") - remaining_words.append("D") - continue - elif word == "UCLA": - remaining_words.append("U") - remaining_words.append("C") - remaining_words.append("L") - remaining_words.append("A") - continue - elif word == "ONTO": - remaining_words.append("ON") - remaining_words.append("TO") - continue - elif word == "DAY": - try: - if text_split[idx + 1] == "CARE": - remaining_words.append("DAYCARE") - word_to_skip = 1 - except: - remaining_words.append(word) - continue - remaining_words.append(word) - - return " ".join(remaining_words) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" - ) - parser.add_argument( - "ref", - type=str, - help="sclite's standard transcription(trn) reference file", - ) - parser.add_argument( - "hyp", - type=str, - help="sclite's standard transcription(trn) hypothesis file", - ) - parser.add_argument( - "work_dir", - type=str, - help="working dir", - ) - args = parser.parse_args() - - if not os.path.isdir(args.work_dir): - os.mkdir(args.work_dir) - - REF = os.path.join(args.work_dir, "REF") - HYP = os.path.join(args.work_dir, "HYP") - RESULT = os.path.join(args.work_dir, "RESULT") - - for io in [(args.ref, REF), (args.hyp, HYP)]: - with open(io[0], "r", encoding="utf8") as fi: - with open(io[1], "w+", encoding="utf8") as fo: - for line in fi: - line = line.strip() - if line: - cols = line.split() - text = asr_text_post_processing(" ".join(cols[0:-1])) - uttid_field = cols[-1] - print(f"{text} {uttid_field}", file=fo) - - # GigaSpeech's uttid comforms to swb - os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py deleted file mode 120000 index 16354dc73..000000000 --- a/egs/swbd/ASR/conformer_ctc/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py deleted file mode 100755 index 5d4438fd1..000000000 --- a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
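Editorial note: `sclite_scoring.py` above normalizes both hypotheses and references before sclite scoring, by uppercasing, dropping conversational fillers and garbage tags, and applying a handful of spelling merges and splits. The snippet below is a tiny self-contained sketch of that normalization idea only; the filler list and rewrite table are deliberately abbreviated and are not the recipe's exact lists.

```python
# Minimal sketch of the scoring-side text normalization used above:
# uppercase, drop fillers/garbage tags, apply a few word rewrites.
NON_SCORING = {"UH", "UM", "MM", "HMM", "[LAUGHTER]", "[NOISE]", "[SILENCE]"}
REWRITES = {"CANCELLED": ["CANCELED"], "ONTO": ["ON", "TO"]}


def normalize_for_scoring(text: str) -> str:
    out = []
    for word in text.upper().split():
        if word in NON_SCORING:
            continue
        out.extend(REWRITES.get(word, [word]))
    return " ".join(out)


if __name__ == "__main__":
    hyp = "um i moved onto the [noise] next topic"
    print(normalize_for_scoring(hyp))  # -> "I MOVED ON TO THE NEXT TOPIC"
```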
- -from distutils.version import LooseVersion - -import torch -from label_smoothing import LabelSmoothingLoss - -torch_ver = LooseVersion(torch.__version__) - - -def test_with_torch_label_smoothing_loss(): - if torch_ver < LooseVersion("1.10.0"): - print(f"Current torch version: {torch_ver}") - print("Please use torch >= 1.10 to run this test - skipping") - return - torch.manual_seed(20211105) - x = torch.rand(20, 30, 5000) - tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) - for reduction in ["none", "sum", "mean"]: - custom_loss_func = LabelSmoothingLoss( - ignore_index=-1, label_smoothing=0.1, reduction=reduction - ) - custom_loss = custom_loss_func(x, tgt) - - torch_loss_func = torch.nn.CrossEntropyLoss( - ignore_index=-1, reduction=reduction, label_smoothing=0.1 - ) - torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) - assert torch.allclose(custom_loss, torch_loss) - - -def main(): - test_with_torch_label_smoothing_loss() - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py deleted file mode 100755 index 81fa234dd..000000000 --- a/egs/swbd/ASR/conformer_ctc/test_subsampling.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from subsampling import Conv2dSubsampling, VggSubsampling - - -def test_conv2d_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = Conv2dSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim - - -def test_vgg_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = VggSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py deleted file mode 120000 index 8b0990ec6..000000000 --- a/egs/swbd/ASR/conformer_ctc/test_transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py deleted file mode 100755 index 7f1eebbcf..000000000 --- a/egs/swbd/ASR/conformer_ctc/train.py +++ /dev/null @@ -1,814 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# Modified by Zengrui Jin for the SwitchBoard corpus -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./conformer_ctc/train.py \ - --exp-dir ./conformer_ctc/exp \ - --world-size 4 \ - --max-duration 200 \ - --num-epochs 20 -""" - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import SwitchBoardAsrDataModule -from conformer import Conformer -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=98, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.8, - help="""The attention rate. 
- The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--num-decoder-layers", - type=int, - default=6, - help="""Number of decoder layer of transformer decoder. - Setting this to 0 will not create the decoder at all (pure CTC model) - """, - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Normalization for the input features, can be a - boolean indicating whether to do batch - normalization, or a float which means just scaling - the input features with this float value. - If given a float value, we will remove batchnorm - layer in `ConvolutionModule` as well. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for Noam - "weight_decay": 1e-6, - "warm_step": 80000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. 
- - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
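Editorial note: `compute_loss`, whose docstring ends just above, wraps the forward pass in `torch.set_grad_enabled(is_training)` and interpolates the two objectives as `(1 - att_rate) * ctc_loss + att_rate * att_loss`. The sketch below shows only that control flow; it substitutes PyTorch's built-in `F.ctc_loss` for `k2.ctc_loss` and a zero placeholder for the attention-decoder loss, and all shapes, the toy encoder, and the function name are made up for illustration.

```python
# Toy sketch of the compute_loss() control flow: forward under
# set_grad_enabled, a CTC term, an optional attention term, interpolation.
import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_compute_loss(
    model: nn.Module,
    feats: torch.Tensor,
    targets: torch.Tensor,
    target_lens: torch.Tensor,
    att_rate: float,
    is_training: bool,
) -> torch.Tensor:
    with torch.set_grad_enabled(is_training):
        log_probs = model(feats).log_softmax(dim=-1)  # (N, T, C)
        input_lens = torch.full((feats.size(0),), log_probs.size(1), dtype=torch.long)
        ctc_loss = F.ctc_loss(
            log_probs.transpose(0, 1),  # F.ctc_loss expects (T, N, C)
            targets,
            input_lens,
            target_lens,
            blank=0,
            reduction="sum",
        )
        if att_rate != 0.0:
            # Placeholder for the transformer decoder's cross-entropy loss.
            att_loss = log_probs.new_tensor(0.0)
            loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
        else:
            loss = ctc_loss
    assert loss.requires_grad == is_training
    return loss


if __name__ == "__main__":
    model = nn.Linear(80, 500)        # toy "encoder": 80-dim features, 500 tokens
    feats = torch.randn(2, 50, 80)    # (N, T, C)
    targets = torch.randint(1, 500, (2, 10))
    target_lens = torch.tensor([10, 10])
    print(toy_compute_loss(model, feats, targets, target_lens, att_rate=0.8, is_training=True))
```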
- """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): - # Works with a BPE model - token_ids = graph_compiler.texts_to_ids(texts) - decoding_graph = graph_compiler.compile(token_ids) - elif isinstance(graph_compiler, CtcTrainingGraphCompiler): - # Works with a phone lexicon - decoding_graph = graph_compiler.compile(texts) - else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = supervisions["num_frames"].sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() - ) - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if 
world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - if "lang_bpe" in str(params.lang_dir): - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - elif "lang_phone" in str(params.lang_dir): - assert params.att_rate == 0, ( - "Attention decoder training does not support phone lang dirs " - "at this time due to a missing symbol. Set --att-rate=0 " - "for pure CTC training when using a phone-based lang dir." - ) - assert params.num_decoder_layers == 0, ( - "Attention decoder training does not support phone lang dirs " - "at this time due to a missing symbol. " - "Set --num-decoder-layers=0 for pure CTC training when using " - "a phone-based lang dir." - ) - graph_compiler = CtcTrainingGraphCompiler( - lexicon, - device=device, - ) - # Manually add the sos/eos ID with their default values - # from the BPE recipe which we're adapting here. - graph_compiler.sos_id = 1 - graph_compiler.eos_id = 1 - else: - raise ValueError( - f"Unsupported type of lang dir (we expected it to have " - f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - switchboard = SwitchBoardAsrDataModule(args) - - train_cuts = switchboard.train_all_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = switchboard.train_dataloaders(train_cuts) - - valid_cuts = switchboard.dev_cuts() - valid_dl = switchboard.valid_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - SwitchBoardAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py deleted file mode 120000 index 1c3f43fcf..000000000 --- a/egs/swbd/ASR/conformer_ctc/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/swbd/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/swbd/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py deleted file mode 100755 index d446e8ff3..000000000 --- a/egs/swbd/ASR/local/compute_fbank_eval2000.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the SwitchBoard dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import sentencepiece as spm -import torch -from filter_cuts import filter_cuts -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to the bpe.model. 
If not None, we will remove short and - long utterances before extracting features""", - ) - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", - ) - - return parser.parse_args() - - -def compute_fbank_switchboard( - dir_name: str, - bpe_model: Optional[str] = None, - dataset: Optional[str] = None, - perturb_speed: Optional[bool] = True, -): - src_dir = Path(f"data/manifests/{dir_name}") - output_dir = Path(f"data/fbank/{dir_name}") - num_jobs = min(1, os.cpu_count()) - num_mel_bins = 80 - - if bpe_model: - logging.info(f"Loading {bpe_model}") - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - if dataset is None: - dataset_parts = ("all",) - else: - dataset_parts = dataset.split(" ", -1) - - prefix = dir_name - suffix = "jsonl.gz" - manifests = { - "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", - } - assert manifests is not None - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) - - with get_executor() as ex: # Initialize the executor only once. - partition = "all" - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - print(cuts_filename) - if (output_dir / cuts_filename).is_file(): - logging.info(f"{prefix} already exists - skipping.") - return - logging.info(f"Processing {prefix}") - cut_set = CutSet.from_file(manifests[prefix]).resample(16000) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - compute_fbank_switchboard( - dir_name="eval2000", - bpe_model=args.bpe_model, - dataset=args.dataset, - perturb_speed=args.perturb_speed, - ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py deleted file mode 100755 index dd82220c0..000000000 --- a/egs/swbd/ASR/local/compute_fbank_swbd.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the SwitchBoard dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. 
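Editorial note: the eval2000 script above resamples the cuts to 16 kHz, extracts 80-bin fbank features, stores them with `LilcomChunkyWriter`, then trims the cuts to their supervisions and writes the resulting manifest. Below is a minimal sketch of that lhotse pipeline using only the calls that appear in the script; the manifest and output paths are placeholders, not files from the recipe.

```python
# Minimal sketch of the lhotse fbank-extraction pipeline used above.
# "data/manifests/cuts.jsonl.gz" is a placeholder manifest path.
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter

# Avoid CPU oversubscription when feature jobs run in parallel.
torch.set_num_threads(1)


def extract_fbank(manifest: str, output_dir: str, num_jobs: int = 4) -> None:
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    extractor = Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000))

    cut_set = CutSet.from_file(manifest).resample(16000)
    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=str(out_dir / "feats"),
        num_jobs=num_jobs,
        storage_type=LilcomChunkyWriter,
    )
    # Keep one cut per supervision segment, as the recipe does for eval2000.
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
    cut_set.to_file(out_dir / "cuts_with_feats.jsonl.gz")


if __name__ == "__main__":
    extract_fbank("data/manifests/cuts.jsonl.gz", "data/fbank/example")
```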
-""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import sentencepiece as spm -import torch -from filter_cuts import filter_cuts -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to the bpe.model. If not None, we will remove short and - long utterances before extracting features""", - ) - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", - ) - - parser.add_argument( - "--split-index", - type=int, - required=True, - ) - - return parser.parse_args() - - -def compute_fbank_switchboard( - dir_name: str, - split_index: int, - bpe_model: Optional[str] = None, - dataset: Optional[str] = None, - perturb_speed: Optional[bool] = True, -): - src_dir = Path(f"data/manifests/{dir_name}") - output_dir = Path(f"data/fbank/{dir_name}_split16") - num_jobs = min(1, os.cpu_count()) - num_mel_bins = 80 - - if bpe_model: - logging.info(f"Loading {bpe_model}") - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - if dataset is None: - dataset_parts = ("all",) - else: - dataset_parts = dataset.split(" ", -1) - - prefix = dir_name - suffix = "jsonl.gz" - split_dir = Path("data/manifests/swbd_split16/") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) - - with get_executor() as ex: # Initialize the executor only once. 
- partition = "all" - cuts_filename = ( - f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" - ) - print(cuts_filename) - if (output_dir / cuts_filename).is_file(): - logging.info(f"{prefix} already exists - skipping.") - return - logging.info(f"Processing {prefix}") - cut_set = ( - CutSet.from_file( - split_dir - / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" - ) - .resample(16000) - .to_eager() - .filter(lambda c: c.duration > 2.0) - ) - - if bpe_model: - cut_set = filter_cuts(cut_set, sp) - if perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - min_duration=None, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - - compute_fbank_switchboard( - dir_name="swbd", - split_index=args.split_index, - bpe_model=args.bpe_model, - dataset=args.dataset, - perturb_speed=args.perturb_speed, - ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 100755 index a8d5117c9..000000000 --- a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -""" -Convert a transcript file containing words to a corpus file containing tokens -for LM training with the help of a lexicon. - -If the lexicon contains phones, the resulting LM will be a phone LM; If the -lexicon contains word pieces, the resulting LM will be a word piece LM. - -If a word has multiple pronunciations, the one that appears first in the lexicon -is kept; others are removed. - -If the input transcript is: - - hello zoo world hello - world zoo - foo zoo world hellO - -and if the lexicon is - - SPN - hello h e l l o 2 - hello h e l l o - world w o r l d - zoo z o o - -Then the output is - - h e l l o 2 z o o w o r l d h e l l o 2 - w o r l d z o o - SPN z o o w o r l d SPN -""" - -import argparse -from pathlib import Path -from typing import Dict, List - -from generate_unique_lexicon import filter_multiple_pronunications - -from icefall.lexicon import read_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--transcript", - type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", - ) - parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="", help="The OOV word.") - - return parser.parse_args() - - -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: - """ - Args: - lexicon: - A dict containing pronunciations. Its keys are words and values - are pronunciations (i.e., tokens). - line: - A line of transcript consisting of space(s) separated words. 
- oov_token: - The pronunciation of the oov word if a word in `line` is not present - in the lexicon. - Returns: - Return None. - """ - s = "" - words = line.strip().split() - for i, w in enumerate(words): - tokens = lexicon.get(w, oov_token) - s += " ".join(tokens) - s += " " - print(s.strip()) - - -def main(): - args = get_args() - assert Path(args.lexicon).is_file() - assert Path(args.transcript).is_file() - assert len(args.oov) > 0 - - # Only the first pronunciation of a word is kept - lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) - - lexicon = dict(lexicon) - - assert args.oov in lexicon - - oov_token = lexicon[args.oov] - - with open(args.transcript) as f: - for line in f: - process_line(lexicon=lexicon, line=line, oov_token=oov_token) - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch deleted file mode 100644 index 12c63d612..000000000 --- a/egs/swbd/ASR/local/dict.patch +++ /dev/null @@ -1,380 +0,0 @@ -1d0 -< file: $SWB/data/dictionary/sw-ms98-dict.text -8645a8646 -> uh-hum ah m hh ah m -9006c9007 -< April ey p r ih l ---- -> April ey p r ax l -9144d9144 -< B ay zh aa n iy z -9261c9261 -< Battle b ae t el ---- -> Battle b ae t ax l -10014a10015 -> Chevy sh eh v iy -10211a10213 -> Colorado k ao l ax r aa d ow -10212a10215 -> Colorado' k ao l ax r aa d ow z -10370c10373 -< Creek k r ih k ---- -> Creek k r iy k -10889a10893 -> Eleven ax l eh v ih n -10951c10955 -< Erie ih r iy ---- -> Erie iy r iy -11183c11187 -< Forever f ax r eh v er ---- -> Forever f er eh v er -11231a11236 -> Friday f r ay d iy -11744a11750 -> History hh ih s t r iy -12004a12011,12012 -> Israel ih z r ih l -> Israel's ih z r ih l z -12573a12582 -> Lincoln l ih ng k ih n -12574a12584 -> Lincolns l ih ng k ih n z -13268c13278 -< NAACP eh ey ey s iy p iy ---- -> NAACP eh n ey ey s iy p iy -13286c13296 -< NIT eh ay t iy ---- -> NIT eh n ay t iy -13292c13302 -< NTSC eh t iy eh s s iy ---- -> NTSC eh n t iy eh s s iy -14058a14069 -> Quarter k ow r t er -14059a14071 -> Quarterback k ow r t er b ae k -14060a14073 -> Quarters k ow r t er z -14569a14583 -> Science s ay n s -15087a15102 -> Sunday s ah n d iy -15088a15104 -> Sunday's s ah n d iy z -15089a15106 -> Sundays s ah n d iy z -15290,15291c15307,15308 -< Texan t eh k sh ih n -< Texan's t eh k sh ih n s ---- -> Texan t eh k s ih n -> Texan's t eh k s ih n s -15335a15353 -> Thousands th aw z ih n z -15739c15757 -< Waco w ae k ow ---- -> Waco w ey k ow -15841a15860 -> Weekends w iy k eh n z -16782a16802 -> acceptable eh k s eh p ax b ax l -16833a16854 -> accounting ax k aw n ih ng -16948a16970 -> address ax d r eh s -17281a17304 -> already aa r d iy -17315a17339 -> am m -17709a17734 -> asked ae s t -17847a17873 -> attorney ih t er n iy -17919a17946 -> autopilot ao t ow p ay l ih t -17960a17988 -> awfully ao f l iy -18221a18250 -> basketball b ae s k ax b ao l -18222a18252 -> basketball's b ae s k ax b ao l z -18302a18333 -> become b ah k ah m -18303a18335 -> becomes b iy k ah m z -18344a18377 -> began b ax g en n -18817c18850 -< bottle b aa t el ---- -> bottle b aa t ax l -19332,19333c19365,19367 -< camera's k ae m ax r ax z -< cameras k ae m ax r ax z ---- -> camera k ae m r ax -> camera's k ae m r ax z -> cameras k ae m r ax z -19411a19446 -> capital k ae p ax l -19505a19541 -> carrying k ae r ih ng -20316a20353,20354 -> combination k aa m ih n ey sh ih n -> combinations k aa m ih n ey sh ih n z -20831a20870 -> contracts k aa n t r ae k s -21010a21050 -> costs k ao s -21062a21103 -> 
county k aw n iy -21371a21413 -> cultural k ao l ch ax r ax l -21372a21415 -> culturally k ao l ch ax r ax l iy -21373a21417 -> culture k ao l ch er -21375a21420 -> cultures k ao l ch er z -21543a21589 -> data d ey t ax -22097a22144 -> differently d ih f ax r ih n t l iy -22972a23020 -> effects ax f eh k t s -23016a23065 -> election ax l eh k sh ih n -23018a23068 -> elections ax l eh k sh ih n z -23052a23103 -> eleven ax l eh v ih n -23242a23294 -> enjoyable ae n jh oy ax b ax l -23248a23301 -> enjoys ae n jh oy z -23293a23347 -> entire ih n t ay r -23295a23350,23351 -> entirely ih n t ay r l iy -> entirety ih n t ay r t iy -23745a23802 -> extra eh k s t er -23818a23876 -> facts f ae k s -24508c24566 -< forever f ax r eh v er ---- -> forever f er eh v er -24514c24572 -< forget f ow r g eh t ---- -> forget f er r g eh t -24521a24580 -> forgot f er r g aa t -24522a24582 -> forgotten f er r g aa t ax n -24563a24624 -> forward f ow er d -24680a24742 -> frightening f r ay t n ih ng -24742a24805 -> full-time f ax l t ay m -24862a24926 -> garage g r aa jh -25218a25283 -> grandmother g r ae m ah dh er -25790a25856 -> heavily hh eh v ax l iy -25949a26016 -> history hh ih s t r iy -26038a26106 -> honestly aa n ax s t l iy -26039a26108 -> honesty aa n ax s t iy -26099a26169 -> horror hh ow r -26155a26226 -> houses hh aw z ih z -26184c26255 -< huh-uh hh ah hh ah ---- -> huh-uh ah hh ah -26189c26260 -< hum-um hh m hh m ---- -> hum-um ah m hh ah m -26236a26308 -> hunting hh ah n ih ng -26307a26380,26381 -> ideal ay d iy l -> idealist ay d iy l ih s t -26369a26444 -> imagine m ae jh ih n -26628a26704 -> individuals ih n d ih v ih jh ax l z -26968a27045 -> interest ih n t r ih s t -27184a27262 -> it'd ih d -27702a27781 -> lead l iy d -28378a28458 -> mandatory m ae n d ih t ow r iy -28885a28966 -> minute m ih n ih t -29167a29249 -> mountains m aw t n z -29317a29400 -> mysteries m ih s t r iy z -29318a29402 -> mystery m ih s t r iy -29470a29555 -> nervous n er v ih s -29578,29580c29663,29665 -< nobody n ow b aa d iy -< nobody'll n ow b aa d iy l -< nobody's n ow b aa d iy z ---- -> nobody n ow b ah d iy -> nobody'll n ow b ah d iy l -> nobody's n ow b ah d iy z -29712a29798 -> nuclear n uw k l iy r -29938a30025 -> onto aa n t ax -30051a30139 -> originally ax r ih jh ax l iy -30507a30596 -> particularly p er t ih k y ax l iy -30755a30845 -> perfectly p er f ih k l iy -30820a30911 -> personally p er s n ax l iy -30915a31007 -> physically f ih z ih k l iy -30986a31079 -> pilot p ay l ih t -30987a31081 -> pilot's p ay l ih t s -31227a31322 -> police p l iy s -31513a31609 -> prefer p er f er -31553a31650 -> prepare p r ax p ey r -31578a31676 -> prescription p er s k r ih p sh ih n -31579a31678 -> prescriptions p er s k r ih p sh ih n z -31770a31870 -> products p r aa d ax k s -31821a31922 -> projects p r aa jh eh k s -31908a32010 -> protect p er t eh k t -31909a32012 -> protected p er t eh k t ih d -31911a32015 -> protection p er t eh k sh ih n -31914a32019 -> protection p er t eh k t ih v -32149a32255 -> quarter k ow r t er -32414a32521 -> read r iy d -32785a32893 -> rehabilitation r iy ax b ih l ih t ey sh ih n -33150a33259 -> resource r ih s ow r s -33151a33261 -> resources r iy s ow r s ih z -33539c33649 -< roots r uh t s ---- -> roots r uw t s -33929a34040 -> science s ay n s -34315a34427 -> seventy s eh v ih n iy -34319,34320c34431,34432 -< severe s ax v iy r -< severely s ax v iy r l iy ---- -> severe s ih v iy r -> severely s ih v iy r l iy -35060a35173 -> software s ao f w ey r -35083a35197 -> solid s 
ao l ih d -35084a35199 -> solidly s ao l ih d l iy -35750a35866 -> stood s t ih d -35854a35971 -> strictly s t r ih k l iy -35889c36006 -< stronger s t r ao ng er ---- -> stronger s t r ao ng g er -36192a36310,36311 -> supposed s p ow z -> supposed s p ow s -36510a36630 -> tastes t ey s -36856a36977 -> thoroughly th er r l iy -36866a36988 -> thousands th aw z ih n z -37081c37203 -< toots t uh t s ---- -> toots t uw t s -37157a37280 -> toward t w ow r d -37158a37282 -> towards t w ow r d z -37564a37689 -> twenties t w eh n iy z -37565a37691 -> twentieth t w eh n iy ih th -37637a37764 -> unacceptable ah n ae k s eh p ax b ax l -37728a37856 -> understand ah n d er s t ae n -37860a37989 -> unless ih n l eh s -38040a38170 -> use y uw z -38049a38180 -> uses y uw z ih z -38125a38257 -> various v ah r iy ih s -38202a38335 -> versus v er s ih z -38381c38514 -< wacko w ae k ow ---- -> wacko w ey k ow -38455c38588 -< wanna w aa n ax ---- -> wanna w ah n ax -38675c38808 -< whatnot w ah t n aa t ---- -> whatnot w aa t n aa t -38676a38810 -> whatsoever w aa t s ow eh v er -38890c39024 -< wok w aa k ---- -> wok w ao k -38910a39045 -> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 9aa204863..000000000 --- a/egs/swbd/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. 
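In practice that function is simply a duration predicate handed to CutSet.filter(). A minimal sketch (the 1.0 s / 20.0 s bounds are illustrative and should be chosen from the statistics this script prints; train_cuts is assumed to be a lhotse CutSet):

    def remove_short_and_long_utt(c):
        # Keep only cuts whose duration lies inside the chosen bounds.
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)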
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" - path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" - # path = "./data/fbank/swbd_cuts_all.jsonl.gz" - - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Training Cut statistics: -╒═══════════════════════════╤═══════════╕ -│ Cuts count: │ 167244 │ -├───────────────────────────┼───────────┤ -│ Total duration (hh:mm:ss) │ 281:01:26 │ -├───────────────────────────┼───────────┤ -│ mean │ 6.0 │ -├───────────────────────────┼───────────┤ -│ std │ 3.3 │ -├───────────────────────────┼───────────┤ -│ min │ 2.0 │ -├───────────────────────────┼───────────┤ -│ 25% │ 3.2 │ -├───────────────────────────┼───────────┤ -│ 50% │ 5.2 │ -├───────────────────────────┼───────────┤ -│ 75% │ 8.3 │ -├───────────────────────────┼───────────┤ -│ 99% │ 14.4 │ -├───────────────────────────┼───────────┤ -│ 99.5% │ 14.7 │ -├───────────────────────────┼───────────┤ -│ 99.9% │ 15.0 │ -├───────────────────────────┼───────────┤ -│ max │ 57.5 │ -├───────────────────────────┼───────────┤ -│ Recordings available: │ 167244 │ -├───────────────────────────┼───────────┤ -│ Features available: │ 167244 │ -├───────────────────────────┼───────────┤ -│ Supervisions available: │ 167244 │ -╘═══════════════════════════╧═══════════╛ -Speech duration statistics: -╒══════════════════════════════╤═══════════╤══════════════════════╕ -│ Total speech duration │ 281:01:26 │ 100.00% of recording │ -├──────────────────────────────┼───────────┼──────────────────────┤ -│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ -├──────────────────────────────┼───────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧═══════════╧══════════════════════╛ - -Eval2000 Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 4473 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 03:37:13 │ -├───────────────────────────┼──────────┤ -│ mean │ 2.9 │ -├───────────────────────────┼──────────┤ -│ std │ 2.6 │ -├───────────────────────────┼──────────┤ -│ min │ 0.1 │ -├───────────────────────────┼──────────┤ -│ 25% │ 1.2 │ -├───────────────────────────┼──────────┤ -│ 50% │ 2.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 4.0 │ -├───────────────────────────┼──────────┤ -│ 99% │ 12.6 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 13.7 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 14.7 │ -├───────────────────────────┼──────────┤ -│ max │ 15.5 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 4473 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 4473 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 4473 │ -╘═══════════════════════════╧══════════╛ -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 03:37:13 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:00 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -""" diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl deleted file mode 100755 index e8b4894d5..000000000 --- 
a/egs/swbd/ASR/local/extend_segments.pl +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env perl -use warnings; #sed replacement for -w perl parameter - -if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { - print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . - "e.g. extend_segments.pl 0.25 segments.2\n" . - "This command modifies a segments file, with lines like\n" . - " \n" . - "by extending the beginning and end of each segment by a certain\n" . - "length of time. This script makes sure the output segments do not\n" . - "overlap as a result of this time-extension, and that there are no\n" . - "negative times in the output.\n"; - exit 1; -} - -$extend = $ARGV[0]; - -@all_lines = (); - -while () { - chop; - @A = split(" ", $_); - if (@A != 4) { - die "invalid line in segments file: $_"; - } - $line = @all_lines; # current number of lines. - ($utt_id, $reco_id, $start_time, $end_time) = @A; - - push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. - if (! defined $lines_for_reco{$reco_id}) { - $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. - } - push @{$lines_for_reco{$reco_id}}, $line; -} - -foreach $reco_id (keys %lines_for_reco) { - $ref = $lines_for_reco{$reco_id}; - @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; - - - { - # handle start of earliest segment as a special case. - $l0 = $line_numbers[0]; - $tstart = ${$all_lines[$l0]}[2] - $extend; - if ($tstart < 0.0) { $tstart = 0.0; } - ${$all_lines[$l0]}[2] = $tstart; - } - { - # handle end of latest segment as a special case. - $lN = $line_numbers[$#line_numbers]; - $tend = ${$all_lines[$lN]}[3] + $extend; - ${$all_lines[$lN]}[3] = $tend; - } - for ($i = 0; $i < $#line_numbers; $i++) { - $ln = $line_numbers[$i]; - $ln1 = $line_numbers[$i+1]; - $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. - $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. - if ($tend > $tstart) { - $utt1 = ${$all_lines[$ln]}[0]; - $utt2 = ${$all_lines[$ln1]}[0]; - print STDERR "Warning: for utterances $utt1 and $utt2, segments " . - "already overlap; leaving these times unchanged.\n"; - } else { - $my_extend = $extend; - $max_extend = 0.5 * ($tstart - $tend); - if ($my_extend > $max_extend) { $my_extend = $max_extend; } - $tend += $my_extend; - $tstart -= $my_extend; - ${$all_lines[$ln]}[3] = $tend; - ${$all_lines[$ln1]}[2] = $tstart; - } - } -} - -# leave the numbering of the lines unchanged. -for ($l = 0; $l < @all_lines; $l++) { - $ref = $all_lines[$l]; - ($utt_id, $reco_id, $start_time, $end_time) = @$ref; - printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); -} - -__END__ - -# testing below. - -# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 -a1 A 0.00 2.00 -a2 A 2.00 5.00 -b1 B 0.00 1.50 -b2 B 1.50 4.00 -# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 -Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. -a1 A 0.00 2.00 -a2 A 1.00 4.00 -# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 -a1 A 0.00 2.50 -a2 A 4.50 7.00 -a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py deleted file mode 100755 index fbcc9e24a..000000000 --- a/egs/swbd/ASR/local/filter_cuts.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script removes short and long utterances from a cutset. - -Caution: - You may need to tune the thresholds for your own dataset. - -Usage example: - - python3 ./local/filter_cuts.py \ - --bpe-model data/lang_bpe_500/bpe.model \ - --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ - --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=Path, - help="Path to the bpe.model", - ) - - parser.add_argument( - "--in-cuts", - type=Path, - help="Path to the input cutset", - ) - - parser.add_argument( - "--out-cuts", - type=Path, - help="Path to the output cutset", - ) - - return parser.parse_args() - - -def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): - total = 0 # number of total utterances before removal - removed = 0 # number of removed utterances - - def remove_short_and_long_utterances(c: Cut): - """Return False to exclude the input cut""" - nonlocal removed, total - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ./display_manifest_statistics.py - # - # You should use ./display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - total += 1 - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - removed += 1 - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./pruned_transducer_stateless2/conformer.py, the - # conv module uses the following expression - # for subsampling - if c.num_frames is None: - num_frames = c.duration * 100 # approximate - else: - num_frames = c.num_frames - - T = ((num_frames - 1) // 2 - 1) // 2 - # Note: for ./lstm_transducer_stateless/lstm.py, the formula is - # T = ((num_frames - 3) // 2 - 1) // 2 - - # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is - # T = ((num_frames - 7) // 2 + 1) // 2 - - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - removed += 1 - return False - - return True - - # We use to_eager() here so that we can print out the value of total - # and removed below. 
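# A quick numeric check of the conformer formula above (illustrative only):
# a 2 s cut at 100 frames/s has num_frames = 200, so
#   T = ((200 - 1) // 2 - 1) // 2 = 49
# i.e. such a cut is excluded only if its transcript exceeds 49 BPE tokens.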
- ans = cut_set.filter(remove_short_and_long_utterances).to_eager() - ratio = removed / total * 100 - logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." - ) - return ans - - -def main(): - args = get_args() - logging.info(vars(args)) - - if args.out_cuts.is_file(): - logging.info(f"{args.out_cuts} already exists - skipping") - return - - assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" - assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - cut_set = load_manifest_lazy(args.in_cuts) - assert isinstance(cut_set, CutSet) - - cut_set = filter_cuts(cut_set, sp) - logging.info(f"Saving to {args.out_cuts}") - args.out_cuts.parent.mkdir(parents=True, exist_ok=True) - cut_set.to_file(args.out_cuts) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py deleted file mode 100755 index 13b35980b..000000000 --- a/egs/swbd/ASR/local/filter_empty_text.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from typing import List - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--kaldi-data-dir", - type=Path, - required=True, - help="Path to the kaldi data dir", - ) - - return parser.parse_args() - - -def load_segments(path: Path): - segments = {} - with open(path, "r") as f: - lines = f.readlines() - for line in lines: - line = line.strip() - utt_id, rec_id, start, end = line.split() - segments[utt_id] = line - return segments - - -def filter_text(path: Path): - with open(path, "r") as f: - lines = f.readlines() - return list(filter(lambda x: len(x.strip().split()) > 1, lines)) - - -def write_segments(path: Path, texts: List[str]): - with open(path, "w") as f: - f.writelines(texts) - - -def main(): - args = get_args() - orig_text_dict = filter_text(args.kaldi_data_dir / "text") - write_segments(args.kaldi_data_dir / "text", orig_text_dict) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() - - logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py deleted file mode 100755 index fa598dd03..000000000 --- a/egs/swbd/ASR/local/format_acronyms_dict.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2015 Minhua Wu -# Apache 2.0 - -# convert acronyms in swbd dict to fisher convention -# IBM to i._b._m. -# BBC to b._b._c. 
-# BBCs to b._b._c.s -# BBC's to b._b._c.'s - -import argparse -import re - -__author__ = "Minhua Wu" - -parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") -parser.add_argument("-i", "--input", help="Input lexicon", required=True) -parser.add_argument("-o", "--output", help="Output lexicon", required=True) -parser.add_argument( - "-L", "--Letter", help="Input single letter pronunciation", required=True -) -parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) -args = parser.parse_args() - - -fin_lex = open(args.input, "r") -fin_Letter = open(args.Letter, "r") -fout_lex = open(args.output, "w") -fout_map = open(args.Map, "w") - -# Initialise single letter dictionary -dict_letter = {} -for single_letter_lex in fin_Letter: - items = single_letter_lex.split() - dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() -fin_Letter.close() -# print dict_letter - -for lex in fin_lex: - items = lex.split() - word = items[0] - lexicon = lex[len(items[0]) + 1 :].strip() - # find acronyms from words with only letters and ' - pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) - if pre_match: - # find if words in the form of xxx's is acronym - if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): - actual_word = word[:-2] - actual_lexicon = lexicon[:-2] - acronym_lexicon = "" - for w in actual_word: - acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " - if acronym_lexicon.strip() == actual_lexicon: - acronym_mapped = "" - acronym_mapped_back = "" - for w in actual_word[:-1]: - acronym_mapped = acronym_mapped + w.lower() + "._" - acronym_mapped_back = acronym_mapped_back + w.lower() + " " - acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" - acronym_mapped_back = ( - acronym_mapped_back + actual_word[-1].lower() + "'s" - ) - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) - fout_lex.write(acronym_mapped + " " + lexicon + "\n") - else: - fout_lex.write(lex) - - # find if words in the form of xxxs is acronym - elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): - actual_word = word[:-1] - actual_lexicon = lexicon[:-2] - acronym_lexicon = "" - for w in actual_word: - acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " - if acronym_lexicon.strip() == actual_lexicon: - acronym_mapped = "" - acronym_mapped_back = "" - for w in actual_word[:-1]: - acronym_mapped = acronym_mapped + w.lower() + "._" - acronym_mapped_back = acronym_mapped_back + w.lower() + " " - acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" - acronym_mapped_back = ( - acronym_mapped_back + actual_word[-1].lower() + "'s" - ) - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) - fout_lex.write(acronym_mapped + " " + lexicon + "\n") - else: - fout_lex.write(lex) - - # find if words in the form of xxx (not ended with 's or s) is acronym - elif word.find("'") == -1 and word[-1] != "s": - acronym_lexicon = "" - for w in word: - acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " - if acronym_lexicon.strip() == lexicon: - acronym_mapped = "" - acronym_mapped_back = "" - for w in word[:-1]: - acronym_mapped = acronym_mapped + w.lower() + "._" - acronym_mapped_back = acronym_mapped_back + w.lower() + " " - acronym_mapped = acronym_mapped + word[-1].lower() + "." 
- acronym_mapped_back = acronym_mapped_back + word[-1].lower() - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) - fout_lex.write(acronym_mapped + " " + lexicon + "\n") - else: - fout_lex.write(lex) - else: - fout_lex.write(lex) - - else: - fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py deleted file mode 100755 index 3459c2f5a..000000000 --- a/egs/swbd/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input a lexicon.txt and output a new lexicon, -in which each word has a unique pronunciation. - -The way to do this is to keep only the first pronunciation of a word -in lexicon.txt. -""" - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -from icefall.lexicon import read_lexicon, write_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - This file will generate a new file uniq_lexicon.txt - in it. - """, - ) - - return parser.parse_args() - - -def filter_multiple_pronunications( - lexicon: List[Tuple[str, List[str]]] -) -> List[Tuple[str, List[str]]]: - """Remove multiple pronunciations of words from a lexicon. - - If a word has more than one pronunciation in the lexicon, only - the first one is kept, while other pronunciations are removed - from the lexicon. - - Args: - lexicon: - The input lexicon, containing a list of (word, [p1, p2, ..., pn]), - where "p1, p2, ..., pn" are the pronunciations of the "word". - Returns: - Return a new lexicon where each word has a unique pronunciation. 
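    For example, given the illustrative input

        hello h e l l o
        hello h e l l o 2
        world w o r l d

    only the first "hello" entry survives, so the output is

        hello h e l l o
        world w o r l d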
- """ - seen = set() - ans = [] - - for word, tokens in lexicon: - if word in seen: - continue - seen.add(word) - ans.append((word, tokens)) - return ans - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - lexicon_filename = lang_dir / "lexicon.txt" - - in_lexicon = read_lexicon(lexicon_filename) - - out_lexicon = filter_multiple_pronunications(in_lexicon) - - write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) - - logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") - logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py deleted file mode 100755 index ba02aaec3..000000000 --- a/egs/swbd/ASR/local/map_acronyms_transcripts.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2015 Minhua Wu -# Apache 2.0 - -# convert acronyms in swbd transcript to fisher convention -# according to first two columns in the input acronyms mapping - -import argparse -import re - -__author__ = "Minhua Wu" - -parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") -parser.add_argument("-i", "--input", help="Input transcripts", required=True) -parser.add_argument("-o", "--output", help="Output transcripts", required=True) -parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) -args = parser.parse_args() - -fin_map = open(args.Map, "r") -dict_acronym = {} -dict_acronym_noi = {} # Mapping of acronyms without I, i -for pair in fin_map: - items = pair.split("\t") - dict_acronym[items[0]] = items[1] - dict_acronym_noi[items[0]] = items[1] -fin_map.close() -del dict_acronym_noi["I"] -del dict_acronym_noi["i"] - - -fin_trans = open(args.input, "r") -fout_trans = open(args.output, "w") -for line in fin_trans: - items = line.split() - L = len(items) - # First pass mapping to map I as part of acronym - for i in range(L): - if items[i] == "I": - x = 0 - while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): - x += 1 - - y = 0 - while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): - y += 1 - - if x + y > 0: - for bias in range(-x, y + 1): - items[i + bias] = dict_acronym[items[i + bias]] - - # Second pass mapping (not mapping 'i' and 'I') - for i in range(len(items)): - if items[i] in dict_acronym_noi.keys(): - items[i] = dict_acronym_noi[items[i]] - sentence = " ".join(items[1:]) - fout_trans.write(items[0] + " " + sentence.lower() + "\n") - -fin_trans.close() -fout_trans.close() diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py deleted file mode 100755 index 20ab90caf..000000000 --- a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py +++ /dev/null @@ -1,283 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import re -from typing import Tuple - -from lhotse import SupervisionSegment, SupervisionSet -from lhotse.serialization import load_manifest_lazy_or_eager -from tqdm import tqdm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("input_sups") - parser.add_argument("output_sups") - return parser.parse_args() - - -# replacement function to convert lowercase letter to uppercase -def to_upper(match_obj): - if match_obj.group() is not None: - return match_obj.group().upper() - - -def insert_groups_and_capitalize_3(match): - return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() - - -def insert_groups_and_capitalize_2(match): - return f"{match.group(1)} {match.group(2)}".upper() - - -def insert_groups_and_capitalize_1(match): - return f"{match.group(1)}".upper() - - -def insert_groups_and_capitalize_1s(match): - return f"{match.group(1)}".upper() + "'s" - - -class FisherSwbdNormalizer: - """Note: the functions "normalize" and "keep" implement the logic - similar to Kaldi's data prep scripts for Fisher and SWBD: One - notable difference is that we don't change [cough], [lipsmack], - etc. to [noise]. We also don't implement all the edge cases of - normalization from Kaldi (hopefully won't make too much - difference). - """ - - def __init__(self) -> None: - self.remove_regexp_before = re.compile( - r"|".join( - [ - # special symbols - r"\[\[skip.*\]\]", - r"\[skip.*\]", - r"\[pause.*\]", - r"\[silence\]", - r"", - r"", - r"_1", - ] - ) - ) - - # tuples of (pattern, replacement) - # note: Kaldi replaces sighs, coughs, etc with [noise]. - # We don't do that here. - # We also lowercase the text as the first operation. 
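# Two standalone checks of how patterns of this kind behave (illustrative,
# lowercase because normalize() lowercases the text before substituting):
#   re.sub(r"\[laughter-(.*?)\]", r"\1", "[laughter-story]")       -> "story"
#   re.sub(r"\[\S+/(\S+)\]", r"\1", "[wea[sonable]-/reasonable]")  -> "reasonable"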
- self.replace_regexps: Tuple[re.Pattern, str] = [ - # SWBD: - # [LAUGHTER-STORY] -> STORY - (re.compile(r"\[laughter-(.*?)\]"), r"\1"), - # [WEA[SONABLE]-/REASONABLE] - (re.compile(r"\[\S+/(\S+)\]"), r"\1"), - # -[ADV]AN[TAGE]- -> AN - (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), - # ABSOLUTE[LY]- -> ABSOLUTE- - (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), - # [AN]Y- -> Y- - # -[AN]Y- -> Y- - (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), - # special tokens - (re.compile(r"\[laugh.*?\]"), r"[laughter]"), - (re.compile(r"\[sigh.*?\]"), r"[sigh]"), - (re.compile(r"\[cough.*?\]"), r"[cough]"), - (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), - (re.compile(r"\[breath.*?\]"), r"[breath]"), - (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), - (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), - # abbreviations - ( - re.compile( - r"(\w)\.(\w)\.(\w)", - ), - insert_groups_and_capitalize_3, - ), - ( - re.compile( - r"(\w)\.(\w)", - ), - insert_groups_and_capitalize_2, - ), - ( - re.compile( - r"([a-h,j-z])\.", - ), - insert_groups_and_capitalize_1, - ), - ( - re.compile( - r"\._", - ), - r" ", - ), - ( - re.compile( - r"_(\w)", - ), - insert_groups_and_capitalize_1, - ), - ( - re.compile( - r"(\w)\.s", - ), - insert_groups_and_capitalize_1s, - ), - ( - re.compile( - r"([A-Z])\'s", - ), - insert_groups_and_capitalize_1s, - ), - ( - re.compile( - r"(\s\w\b|^\w\b)", - ), - insert_groups_and_capitalize_1, - ), - # words between apostrophes - (re.compile(r"'(\S*?)'"), r"\1"), - # dangling dashes (2 passes) - (re.compile(r"\s-\s"), r" "), - (re.compile(r"\s-\s"), r" "), - # special symbol with trailing dash - (re.compile(r"(\[.*?\])-"), r"\1"), - # Just remove all dashes - (re.compile(r"-"), r" "), - ] - - # unwanted symbols in the transcripts - self.remove_regexp_after = re.compile( - r"|".join( - [ - # remaining punctuation - r"\.", - r",", - r"\?", - r"{", - r"}", - r"~", - r"_\d", - ] - ) - ) - - self.post_fixes = [ - # Fix an issue related to [VOCALIZED NOISE] after dash removal - (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), - ] - - self.whitespace_regexp = re.compile(r"\s+") - - def normalize(self, text: str) -> str: - text = text.lower() - - # first remove - text = self.remove_regexp_before.sub("", text) - - # then replace - for pattern, sub in self.replace_regexps: - text = pattern.sub(sub, text) - - # then remove - text = self.remove_regexp_after.sub("", text) - - # post fixes - for pattern, sub in self.post_fixes: - text = pattern.sub(sub, text) - - # then clean up whitespace - text = self.whitespace_regexp.sub(" ", text).strip() - - return text.upper() - - -def keep(sup: SupervisionSegment) -> bool: - if "((" in sup.text: - return False - - if " yes", - "[laugh] oh this is [laught] this is great [silence] yes", - "i don't kn- - know A.B.C's", - "so x. 
corp is good?", - "'absolutely yes", - "absolutely' yes", - "'absolutely' yes", - "'absolutely' yes 'aight", - "ABSOLUTE[LY]", - "ABSOLUTE[LY]-", - "[AN]Y", - "[AN]Y-", - "[ADV]AN[TAGE]", - "[ADV]AN[TAGE]-", - "-[ADV]AN[TAGE]", - "-[ADV]AN[TAGE]-", - "[WEA[SONABLE]-/REASONABLE]", - "[VOCALIZED-NOISE]-", - "~BULL", - "Frank E Peretti P E R E T T I", - "yeah yeah like Double O Seven he's supposed to do it", - "P A P E R paper", - "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", - ]: - print(text) - print(normalizer.normalize(text)) - print() - - -if __name__ == "__main__": - test() - # exit() - main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py deleted file mode 100755 index 7316193d0..000000000 --- a/egs/swbd/ASR/local/normalize_eval2000.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import re -from typing import Tuple - -from lhotse import SupervisionSegment, SupervisionSet -from lhotse.serialization import load_manifest_lazy_or_eager -from tqdm import tqdm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("input_sups") - parser.add_argument("output_sups") - return parser.parse_args() - - -def remove_punctutation_and_other_symbol(text: str) -> str: - text = text.replace("--", " ") - text = text.replace("//", " ") - text = text.replace(".", " ") - text = text.replace("?", " ") - text = text.replace("~", " ") - text = text.replace(",", " ") - text = text.replace(";", " ") - text = text.replace("(", " ") - text = text.replace(")", " ") - text = text.replace("&", " ") - text = text.replace("%", " ") - text = text.replace("*", " ") - text = text.replace("{", " ") - text = text.replace("}", " ") - return text - - -def eval2000_clean_eform(text: str, eform_count) -> str: - string_to_remove = [] - piece = text.split('">') - for i in range(0, len(piece)): - s = piece[i] + '">' - res = re.search(r"", s) - if res is not None: - res_rm = res.group(1) - string_to_remove.append(res_rm) - for p in string_to_remove: - eform_string = p - text = text.replace(eform_string, " ") - eform_1 = " str: - text = text.replace("[/BABY CRYING]", " ") - text = text.replace("[/CHILD]", " ") - text = text.replace("[[DISTORTED]]", " ") - text = text.replace("[/DISTORTION]", " ") - text = text.replace("[[DRAWN OUT]]", " ") - text = text.replace("[[DRAWN-OUT]]", " ") - text = text.replace("[[FAINT]]", " ") - text = text.replace("[SMACK]", " ") - text = text.replace("[[MUMBLES]]", " ") - text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") - text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") - text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") - text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") - text = text.replace("[[PREVIOUS WORD SPOKEN 
WITH A LAUGH]]", " ") - text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") - text = text.replace("[[PROLONGED]]", " ") - text = text.replace("[/RUNNING WATER]", " ") - text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") - text = text.replace("[[SINGING]]", " ") - text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") - text = text.replace("[/STATIC]", " ") - text = text.replace("['THIRTIETH' DRAWN OUT]", " ") - text = text.replace("[/VOICES]", " ") - text = text.replace("[[WHISPERED]]", " ") - text = text.replace("[DISTORTION]", " ") - text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") - text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") - text = text.replace("[CHILD'S VOICE]", " ") - text = text.replace("[CHILD SCREAMS]", " ") - text = text.replace("[CHILD VOICE]", " ") - text = text.replace("[CHILD YELLING]", " ") - text = text.replace("[CHILD SCREAMING]", " ") - text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") - text = text.replace("[CHANNEL NOISE]", " ") - text = text.replace("[CHANNEL ECHO]", " ") - text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") - text = text.replace("[ECHO OF OTHER CHANNEL]", " ") - text = text.replace("[CLICK]", " ") - text = text.replace("[DISTORTED]", " ") - text = text.replace("[BABY CRYING]", " ") - text = text.replace("[METALLIC KNOCKING SOUND]", " ") - text = text.replace("[METALLIC SOUND]", " ") - - text = text.replace("[PHONE JIGGLING]", " ") - text = text.replace("[BACKGROUND SOUND]", " ") - text = text.replace("[BACKGROUND VOICE]", " ") - text = text.replace("[BACKGROUND VOICES]", " ") - text = text.replace("[BACKGROUND NOISE]", " ") - text = text.replace("[CAR HORNS IN BACKGROUND]", " ") - text = text.replace("[CAR HORNS]", " ") - text = text.replace("[CARNATING]", " ") - text = text.replace("[CRYING CHILD]", " ") - text = text.replace("[CHOPPING SOUND]", " ") - text = text.replace("[BANGING]", " ") - text = text.replace("[CLICKING NOISE]", " ") - text = text.replace("[CLATTERING]", " ") - text = text.replace("[ECHO]", " ") - text = text.replace("[KNOCK]", " ") - text = text.replace("[NOISE-GOOD]", "[NOISE]") - text = text.replace("[RIGHT]", " ") - text = text.replace("[SOUND]", " ") - text = text.replace("[SQUEAK]", " ") - text = text.replace("[STATIC]", " ") - text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") - text = text.replace("[UH]", "UH") - text = text.replace("[MN]", "[VOCALIZED-NOISE]") - text = text.replace("[VOICES]", " ") - text = text.replace("[WATER RUNNING]", " ") - text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") - text = text.replace("[SOUND OF SOMETHING FALLING]", " ") - text = text.replace("[SOUND]", " ") - text = text.replace("[NOISE OF MOVING PHONE]", " ") - text = text.replace("[SOUND OF RUNNING WATER]", " ") - text = text.replace("[CHANNEL]", " ") - text = text.replace("[SILENCE]", " ") - text = text.replace("-[W]HERE", "WHERE") - text = text.replace("Y[OU]I-", "YOU I") - text = text.replace("-[A]ND", "AND") - text = text.replace("JU[ST]", "JUST") - text = text.replace("{BREATH}", " ") - text = text.replace("{BREATHY}", " ") - text = text.replace("{CHANNEL NOISE}", " ") - text = text.replace("{CLEAR THROAT}", " ") - - text = text.replace("{CLEARING THROAT}", " ") - text = text.replace("{CLEARS THROAT}", " ") - text = text.replace("{COUGH}", " ") - text = text.replace("{DRAWN OUT}", " ") - text = text.replace("{EXHALATION}", " ") - text = text.replace("{EXHALE}", " ") - text = text.replace("{GASP}", " ") - text = 
text.replace("{HIGH SQUEAL}", " ") - text = text.replace("{INHALE}", " ") - text = text.replace("{LAUGH}", "[LAUGHTER]") - text = text.replace("{LAUGH}", "[LAUGHTER]") - text = text.replace("{LAUGH}", "[LAUGHTER]") - text = text.replace("{LIPSMACK}", " ") - text = text.replace("{LIPSMACK}", " ") - - text = text.replace("{NOISE OF DISGUST}", " ") - text = text.replace("{SIGH}", " ") - text = text.replace("{SNIFF}", " ") - text = text.replace("{SNORT}", " ") - text = text.replace("{SHARP EXHALATION}", " ") - text = text.replace("{BREATH LAUGH}", " ") - - text = text.replace("[LAUGHTER]", " ") - text = text.replace("[NOISE]", " ") - text = text.replace("[VOCALIZED-NOISE]", " ") - text = text.replace("-", " ") - return text - - -def remove_languagetag(text: str) -> str: - langtag = re.findall(r"<(.*?)>", text) - for t in langtag: - text = text.replace(t, " ") - text = text.replace("<", " ") - text = text.replace(">", " ") - return text - - -def eval2000_normalizer(text: str) -> str: - # print("TEXT original: ",text) - eform_count = text.count("contraction e_form") - # print("eform corunt:", eform_count) - if eform_count > 0: - text = eval2000_clean_eform(text, eform_count) - text = text.upper() - text = remove_languagetag(text) - text = replace_silphone(text) - text = remove_punctutation_and_other_symbol(text) - text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") - text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") - spaces = re.findall(r"\s+", text) - for sp in spaces: - text = text.replace(sp, " ") - text = text.strip() - # text = self.whitespace_regexp.sub(" ", text).strip() - # print(text) - return text - - -def main(): - args = get_args() - sups = load_manifest_lazy_or_eager(args.input_sups) - assert isinstance(sups, SupervisionSet) - - tot, skip = 0, 0 - with SupervisionSet.open_writer(args.output_sups) as writer: - for sup in tqdm(sups, desc="Normalizing supervisions"): - tot += 1 - sup.text = eval2000_normalizer(sup.text) - if not sup.text: - skip += 1 - continue - writer.write(sup) - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/swbd/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py deleted file mode 100755 index d82a085ec..000000000 --- a/egs/swbd/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.utils import str2bool - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str], oov: str -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - oov: - The out of vocabulary word in lexicon. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - # Convert word to word piece IDs instead of word piece strings - # to avoid OOV tokens. - words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) - - # Now convert word piece IDs back to word piece strings. 
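# For intuition, the round trip for a single word looks like this
# (the IDs and pieces shown are hypothetical; they depend on bpe.model):
#   sp.encode(["AFTERNOON"], out_type=int)   ->  [[214, 581]]
#   [sp.id_to_piece(i) for i in [214, 581]]  ->  ['▁AFTER', 'NOON']
# so every word in words.txt gets paired with its sentencepiece pieces.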
- words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--oov", - type=str, - default="", - help="The out of vocabulary word in lexicon.", - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = [ - "", - "!SIL", - "", - args.oov, - "#0", - "", - "", - ] - - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py deleted file mode 120000 index abc00d421..000000000 --- a/egs/swbd/ASR/local/prepare_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh deleted file mode 100755 index 8a5f64324..000000000 --- a/egs/swbd/ASR/local/rt03_data_prep.sh +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env bash - -# RT-03 data preparation (conversational telephone speech part only) -# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi - -# To be run from one directory 
above this script. - -# Expects the standard directory layout for RT-03 - -if [ $# -ne 1 ]; then - echo "Usage: $0 " - echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" - echo "See comments in the script for more details" - exit 1 -fi - -sdir=$1 -[ ! -d $sdir/data/audio/eval03/english/cts ] && - echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 -[ ! -d $sdir/data/references/eval03/english/cts ] && - echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 - -dir=data/local/rt03 -mkdir -p $dir - -rtroot=$sdir -tdir=$sdir/data/references/eval03/english/cts -sdir=$sdir/data/audio/eval03/english/cts - -find -L $sdir -iname '*.sph' | sort >$dir/sph.flist -sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - >$dir/sph.scp - -sph2pipe=sph2pipe -! command -v "${sph2pipe}" &>/dev/null && - echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 - -awk -v sph2pipe=$sph2pipe '{ - printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); - printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); -}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 -#side A - channel 1, side B - channel 2 - -# Get segments file... -# segments file format is: utt-id side-id start-time end-time, e.g.: -# sw02001-A_000098-001156 sw02001-A 0.98 11.56 -#pem=$sdir/english/hub5e_00.pem -#[ ! -f $pem ] && echo "No such file $pem" && exit 1; -# pem file has lines like: -# en_4156 A unknown_speaker 301.85 302.48 - -#grep -v ';;' $pem \ -cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | - awk '{ - spk=$1"-"(($2==1)?"A":"B"); - utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - print utt,spk,$4,$5;}' | - sort -u >$dir/segments - -# stm file has lines like: -# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER -# TODO(arnab): We should really be lowercasing this since the Edinburgh -# recipe uses lowercase. This is not used in the actual scoring. -#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ -cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | - awk '{ - spk=$1"-"(($2==1)?"A":"B"); - utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | - sort >$dir/text.all - -# We'll use the stm file for sclite scoring. There seem to be various errors -# in the stm file that upset hubscr.pl, and we fix them here. -cat $tdir/*.stm | - sed -e 's:((:(:' -e 's:::g' -e 's:::g' | - grep -v inter_segment_gap | - awk '{ - printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ - >$dir/stm -#$tdir/reference/hub5e00.english.000405.stm > $dir/stm -cp $rtroot/data/trans_rules/en20030506.glm $dir/glm - -# next line uses command substitution -# Just checking that the segments are the same in pem vs. stm. -! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && - echo "Segments from pem file and stm file do not match." && exit 1 - -grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text - -# create an utt2spk file that assumes each conversation side is -# a separate speaker. 
-awk '{print $1,$2;}' $dir/segments >$dir/utt2spk -utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt - -# cp $dir/segments $dir/segments.tmp -# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ -# $dir/segments.tmp > $dir/segments - -awk '{print $1}' $dir/wav.scp | - perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; - print "$1-$2 $1 $2\n"; ' \ - >$dir/reco2file_and_channel || exit 1 - -./utils/fix_data_dir.sh $dir - -echo Data preparation and formatting completed for RT-03 -echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py deleted file mode 100755 index bed3856e4..000000000 --- a/egs/swbd/ASR/local/sort_lm_training_data.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input the filename of LM training data -generated by ./local/prepare_lm_training_data.py and sorts -it by sentence length. - -Sentence length equals to the number of BPE tokens in a sentence. -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import numpy as np -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--in-lm-data", - type=str, - help="Input LM training data, e.g., data/bpe_500/lm_data.pt", - ) - - parser.add_argument( - "--out-lm-data", - type=str, - help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", - ) - - parser.add_argument( - "--out-statistics", - type=str, - help="Statistics about LM training data., data/bpe_500/statistics.txt", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - in_lm_data = Path(args.in_lm_data) - out_lm_data = Path(args.out_lm_data) - assert in_lm_data.is_file(), f"{in_lm_data}" - if out_lm_data.is_file(): - logging.warning(f"{out_lm_data} exists - skipping") - return - data = torch.load(in_lm_data) - words2bpe = data["words"] - sentences = data["sentences"] - sentence_lengths = data["sentence_lengths"] - - num_sentences = sentences.dim0 - assert num_sentences == sentence_lengths.numel(), ( - num_sentences, - sentence_lengths.numel(), - ) - - indices = torch.argsort(sentence_lengths, descending=True) - - sorted_sentences = sentences[indices.to(torch.int32)] - sorted_sentence_lengths = sentence_lengths[indices] - - # Check that sentences are ordered by length - assert num_sentences == sorted_sentences.dim0, ( - num_sentences, - sorted_sentences.dim0, - ) - - cur = None - for i in range(num_sentences): - word_ids = sorted_sentences[i] - token_ids = words2bpe[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - if cur is not None: - assert cur >= token_ids.numel(), (cur, token_ids.numel()) - - cur = token_ids.numel() - assert cur == sorted_sentence_lengths[i] - - data["sentences"] = sorted_sentences - 
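The `reco2file_and_channel` Perl one-liner in rt03_data_prep.sh above maps a recording id such as `sw02001-A` back to its file name and channel. The same parsing as a hedged Python sketch:

```
import re


def reco2file_and_channel(reco_id: str):
    """Split 'sw02001-A' into ('sw02001', 'A'), mirroring the Perl
    one-liner in rt03_data_prep.sh above."""
    m = re.match(r"^(\S+)-([AB])$", reco_id)
    if m is None:
        raise ValueError(f"bad label {reco_id}")
    return m.group(1), m.group(2)


print(reco2file_and_channel("sw02001-A"))  # ('sw02001', 'A')
```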
data["sentence_lengths"] = sorted_sentence_lengths - torch.save(data, args.out_lm_data) - logging.info(f"Saved to {args.out_lm_data}") - - statistics = Path(args.out_statistics) - - # Write statistics - num_words = sorted_sentences.numel() - num_tokens = sentence_lengths.sum().item() - max_sentence_length = sentence_lengths[indices[0]] - min_sentence_length = sentence_lengths[indices[-1]] - - step = 10 - hist, bins = np.histogram( - sentence_lengths.numpy(), - bins=np.arange(1, max_sentence_length + step, step), - ) - - histogram = np.stack((bins[:-1], hist)).transpose() - - with open(statistics, "w") as f: - f.write(f"num_sentences: {num_sentences}\n") - f.write(f"num_words: {num_words}\n") - f.write(f"num_tokens: {num_tokens}\n") - f.write(f"max_sentence_length: {max_sentence_length}\n") - f.write(f"min_sentence_length: {min_sentence_length}\n") - f.write("histogram:\n") - f.write(" bin count percent\n") - for row in histogram: - f.write( - f"{int(row[0]):>5} {int(row[1]):>5} " - f"{100.*row[1]/num_sentences:.3f}%\n" - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh deleted file mode 100755 index 159359491..000000000 --- a/egs/swbd/ASR/local/swbd1_data_prep.sh +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env bash - -# Switchboard-1 training data preparation customized for Edinburgh -# Author: Arnab Ghoshal (Jan 2013) - -# To be run from one directory above this script. - -## The input is some directory containing the switchboard-1 release 2 -## corpus (LDC97S62). Note: we don't make many assumptions about how -## you unpacked this. We are just doing a "find" command to locate -## the .sph files. - -## The second input is optional, which should point to a directory containing -## Switchboard transcriptions/documentations (specifically, the conv.tab file). -## If specified, the script will try to use the actual speaker PINs provided -## with the corpus instead of the conversation side ID (Kaldi default). We -## will be using "find" to locate this file so we don't make any assumptions -## on the directory structure. (Peng Qi, Aug 2014) - -#check existing directories -if [ $# != 1 -a $# != 2 ]; then - echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" - exit 1 -fi - -SWBD_DIR=$1 - -dir=data/local/train -mkdir -p $dir - -# Audio data directory check -if [ ! -d $SWBD_DIR ]; then - echo "Error: run.sh requires a directory argument" - exit 1 -fi - -sph2pipe=sph2pipe -! command -v "${sph2pipe}" &>/dev/null && - echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 - -# Option A: SWBD dictionary file check -[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && - echo "SWBD dictionary file does not exist" && exit 1 - -# find sph audio files -find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist - -n=$(cat $dir/sph.flist | wc -l) -[ $n -ne 2435 ] && [ $n -ne 2438 ] && - echo Warning: expected 2435 or 2438 data data files, found $n - -# (1a) Transcriptions preparation -# make basic transcription file (add segments info) -# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we -# make everything lowercase here. This is because we will be using SRILM which -# can optionally make everything lowercase (but not uppercase) when mapping -# LM vocabs. 
-awk '{ -name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); -stime=$2; etime=$3; -printf("%s-%s_%06.0f-%06.0f", -name, side, int(100*stime+0.5), int(100*etime+0.5)); -for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" -}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt - -# test if trans. file is sorted -export LC_ALL=C -sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. - -# Remove SILENCE, and . - -# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. -# removing [SILENCE], and the and markers that mark -# speech to somone; we will give phones to the other three (NSN, SPN, LAU). -# There will also be a silence phone, SIL. -# **NOTE: modified the pattern matches to make them case insensitive -cat $dir/transcripts1.txt | - perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; - s///gi; - s///gi; - print;' | - awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt - -# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches -# case insensitive -local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text - -# format acronyms in text -python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ - -M data/local/dict_nosp/acronyms.map -mv $dir/text_map $dir/text - -# (1c) Make segment files from transcript -#segments file format is: utt-id side-id start-time end-time, e.g.: -#sw02001-A_000098-001156 sw02001-A 0.98 11.56 -awk '{ -segment=$1; -split(segment,S,"[_-]"); -side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; -print segment " " audioname "-" side " " startf/100 " " endf/100 -}' <$dir/text >$dir/segments - -sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - >$dir/sph.scp - -awk -v sph2pipe=$sph2pipe '{ -printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); -printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); -}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 -#side A - channel 1, side B - channel 2 - -# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) -# to the file name sw02001 and the A, e.g. -# sw02001-A sw02001 A -# In this case it's trivial, but in other corpora the information might -# be less obvious. Later it will be needed for ctm scoring. -awk '{print $1}' $dir/wav.scp | - perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; - print "$1-$2 $1 $2\n"; ' \ - >$dir/reco2file_and_channel || exit 1 - -awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || - exit 1 -sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 - -echo Switchboard-1 data preparation succeeded. - -utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl deleted file mode 100755 index 4fb8d4ffe..000000000 --- a/egs/swbd/ASR/local/swbd1_map_words.pl +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env perl - -# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern -# matches case-insensitive --Arnab (Jan 2013) - -if ($ARGV[0] eq "-f") { - shift @ARGV; - $field_spec = shift @ARGV; - if ($field_spec =~ m/^\d+$/) { - $field_begin = $field_spec - 1; $field_end = $field_spec - 1; - } - if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) - if ($1 ne "") { - $field_begin = $1 - 1; # Change to zero-based indexing. - } - if ($2 ne "") { - $field_end = $2 - 1; # Change to zero-based indexing. 
- } - } - if (!defined $field_begin && !defined $field_end) { - die "Bad argument to -f option: $field_spec"; - } -} - - -while (<>) { - @A = split(" ", $_); - for ($n = 0; $n < @A; $n++) { - $a = $A[$n]; - if ( (!defined $field_begin || $n >= $field_begin) - && (!defined $field_end || $n <= $field_end)) { - # e.g. [LAUGHTER-STORY] -> STORY; - $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; - # $1 and $3 relate to preserving trailing "-" - $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, - # 1st part may include partial-word stuff, which we process further below, - # e.g. [LEM[GUINI]-/LINGUINI] - # the (|\_) at the end is to accept and preserve trailing -'s. - $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; - # let the leading - be optional on input, as sometimes omitted. - $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; - # let the trailing - be optional on input, as sometimes omitted. - $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- - # which is a mistake in the input. - $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM - $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- - $a =~ s:_\d$::; # e.g. THEM_1 -> THEM - } - $A[$n] = $a; - } - print join(" ", @A) . "\n"; -} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh deleted file mode 100755 index eff5fb5f1..000000000 --- a/egs/swbd/ASR/local/swbd1_prepare_dict.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash - -# Formatting the Mississippi State dictionary for use in Edinburgh. Differs -# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) - -# To be run from one directory above this script. - -#check existing directories -[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 - -srcdir=. # This is where we downloaded some stuff.. -dir=./data/local/dict_nosp -mkdir -p $dir -srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text - -# assume swbd_p1_data_prep.sh was done already. -[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 - -cp $srcdict $dir/lexicon0.txt || exit 1 -chmod a+w $dir/lexicon0.txt -patch 0' | sort >$dir/lexicon1.txt || exit 1 - -cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | - grep -v sil >$dir/nonsilence_phones.txt || exit 1 - -( - echo sil - echo spn - echo nsn - echo lau -) >$dir/silence_phones.txt - -echo sil >$dir/optional_silence.txt - -# No "extra questions" in the input to this setup, as we don't -# have stress or tone. -echo -n >$dir/extra_questions.txt - -cp local/MSU_single_letter.txt $dir/ -# Add to the lexicon the silences, noises etc. -# Add single letter lexicon -# The original swbd lexicon does not have precise single letter lexicion -# e.g. it does not have entry of W -( - echo '!SIL SIL' - echo '[VOCALIZED-NOISE] spn' - echo '[NOISE] nsn' - echo '[LAUGHTER] lau' - echo ' spn' -) | - cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 - -# Map the words in the lexicon. That is-- for each word in the lexicon, we map it -# to a new written form. The transformations we do are: -# remove laughter markings, e.g. -# [LAUGHTER-STORY] -> STORY -# Remove partial-words, e.g. -# -[40]1K W AH N K EY -# becomes -1K -# and -# -[AN]Y IY -# becomes -# -Y -# -[A]B[OUT]- B -# becomes -# -B- -# Also, curly braces, which appear to be used for "nonstandard" -# words or non-words, are removed, e.g. 
-# {WOLMANIZED} W OW L M AX N AY Z D -# -> WOLMANIZED -# Also, mispronounced words, e.g. -# [YEAM/YEAH] Y AE M -# are changed to just e.g. YEAM, i.e. the orthography -# of the mispronounced version. -# Note-- this is only really to be used in training. The main practical -# reason is to avoid having tons of disambiguation symbols, which -# we otherwise would get because there are many partial words with -# the same phone sequences (most problematic: S). -# Also, map -# THEM_1 EH M -> THEM -# so that multiple pronunciations just have alternate entries -# in the lexicon. - -local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ - >$dir/lexicon3.txt || exit 1 - -python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ - -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map -cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map - -(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt - -pushd $dir >&/dev/null -ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. -popd >&/dev/null -rm $dir/lexiconp.txt 2>/dev/null -echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py deleted file mode 100755 index 9b4e28635..000000000 --- a/egs/swbd/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. 
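train_bpe_model.py above relies on `unk_id` being 2, the first id after the two leading user-defined symbols, and other scripts in the recipe depend on that value. A quick way to double-check a trained model; the model path is a placeholder and the printed values are what the recipe expects, not guaranteed output:

```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

print(sp.unk_id())                      # the recipe expects 2
print(sp.id_to_piece(sp.unk_id()))      # usually "<unk>"
print(sp.encode("[LAUGHTER] OKAY", out_type=str))
# "[LAUGHTER]" stays a single piece because it is a user-defined symbol.
```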
- - user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/swbd/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh deleted file mode 100755 index 6b6f4ff86..000000000 --- a/egs/swbd/ASR/prepare.sh +++ /dev/null @@ -1,463 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. Most of them can't be downloaded automatically -# as they are not publically available and require a license purchased -# from the LDC. -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=./download -# swbd1_dir="/export/corpora3/LDC/LDC97S62" -swbd1_dir=./download/LDC97S62/ - -# eval2000_dir contains the following files and directories -# downloaded from LDC website: -# - LDC2002S09 -# - hub5e_00 -# - LDC2002T43 -# - reference -eval2000_dir="/export/corpora2/LDC/eval2000" - -rt03_dir="/export/corpora/LDC/LDC2007S10" -fisher_dir="/export/corpora3/LDC/LDC2004T19" - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "swbd1_dir: $swbd1_dir" -log "eval2000_dir: $eval2000_dir" -log "rt03_dir: $rt03_dir" - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare SwitchBoard manifest" - # We assume that you have downloaded the SwitchBoard corpus - # to respective dirs - mkdir -p data/manifests - if [ ! 
-e data/manifests/.swbd.done ]; then - lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd - ./local/normalize_and_filter_supervisions.py \ - data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ - data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz - mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz - mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz - - lhotse cut simple \ - -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ - -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ - data/manifests/swbd/swbd_train_all.jsonl.gz - lhotse cut trim-to-supervisions \ - --discard-overlapping \ - --discard-extra-channels \ - data/manifests/swbd/swbd_train_all.jsonl.gz \ - data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz - - num_splits=16 - mkdir -p data/manifests/swbd_split${num_splits} - lhotse split ${num_splits} \ - data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ - data/manifests/swbd_split${num_splits} - - lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 - ./local/normalize_eval2000.py \ - data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ - data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz - - lhotse cut simple \ - -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ - -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ - data/manifests/eval2000/eval2000_cuts_all.jsonl.gz - - lhotse cut trim-to-supervisions \ - --discard-overlapping \ - --discard-extra-channels \ - data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ - data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz - - sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ - $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm - cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm - - # ./local/rt03_data_prep.sh $rt03_dir - - # normalize eval2000 and rt03 texts by - # 1) convert upper to lower - # 2) remove tags (%AH) (%HESITATION) (%UH) - # 3) remove - # 4) remove "(" or ")" - # for x in rt03; do - # cp data/local/${x}/text data/local/${x}/text.org - # paste -d "" \ - # <(cut -f 1 -d" " data/local/${x}/text.org) \ - # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | - # sed -e 's/\s\+/ /g' >data/local/${x}/text - # rm data/local/${x}/text.org - # done - - # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests - - touch data/manifests/.swbd.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3 I: Compute fbank for SwitchBoard" - if [ ! 
-e data/fbank/.swbd.done ]; then - mkdir -p data/fbank/swbd_split${num_splits}/ - for index in $(seq 1 16); do - ./local/compute_fbank_swbd.py --split-index ${index} & - done - wait - pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") - lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz - touch data/fbank/.swbd.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3 II: Compute fbank for eval2000" - if [ ! -e data/fbank/.eval2000.done ]; then - mkdir -p data/fbank/eval2000/ - ./local/compute_fbank_eval2000.py - touch data/fbank/.eval2000.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" - lang_dir=data/lang_phone - mkdir -p $lang_dir - - if ! which jq; then - echo "This script is intended to be used with jq but you have not installed jq - Note: in Linux, you can install jq with the following command: - 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - 2. chmod +x ./jq - 3. cp jq /usr/bin" && exit 1 - fi - if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then - log "Prepare text." - gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ - | jq '.text' | sed 's/"//g' > $lang_dir/text - fi - - log "Prepare dict" - ./local/swbd1_prepare_dict.sh - cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt - # [noise] nsn - # !sil sil - # spn - cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | - sort | uniq >$lang_dir/lexicon_lower.txt - - cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - - cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! 
-f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram token-level P for MMI training" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - >$lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa >$lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - lang_dir=data/lang_phone - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text ${lang_dir}/input.txt \ - -lm data/lm/3-gram.arpa - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - ./shared/make_kn_lm.py \ - -ngram-order 4 \ - -text ${lang_dir}/input.txt \ - -lm data/lm/4-gram.arpa - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt - fi -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compile LG" - ./local/compile_lg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Generate LM training data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! 
-f $out_dir/train.txt ]; then - tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt - fi - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data data/lang_phone/input.txt \ - --lm-archive $out_dir/lm_data.pt - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Generate LM validation data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/valid.txt ]; then - head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Generate LM test data" - testsets=(eval2000) - - for testset in ${testsets[@]}; do - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/${testset}.txt ]; then - gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ - | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/${testset}.txt \ - --lm-archive $out_dir/lm_data-${testset}.pt - done - done -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Sort LM training data" - testsets=(eval2000) - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. - - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - for testset in ${testsets[@]}; do - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-${testset}.pt \ - --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ - --out-statistics $out_dir/statistics-test-${testset}.txt - done - done -fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/swbd/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl deleted file mode 100755 index b76d37f41..000000000 --- a/egs/swbd/ASR/utils/filter_scp.pl +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2012 Microsoft Corporation -# Johns Hopkins University (author: Daniel Povey) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. 
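Stages 11-14 above produce lm_data.pt archives with prepare_lm_training_data.py and then sort them with sort_lm_training_data.py. A hedged sketch of inspecting such an archive; the path is a placeholder, the field names follow the sorting script shown earlier, and k2 must be installed because the archive stores ragged tensors:

```
import k2  # the archive stores k2.RaggedTensor objects
import torch

data = torch.load("data/lm_training_bpe_500/lm_data.pt")  # placeholder path

sentences = data["sentences"]                # ragged: sentence -> word ids
words2bpe = data["words"]                    # ragged: word id -> BPE token ids
sentence_lengths = data["sentence_lengths"]  # 1-D tensor of BPE token counts

# The sorting script asserts that these two numbers agree.
print(sentences.dim0, sentence_lengths.numel())
print(words2bpe.dim0)                        # number of distinct words
```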
- - -# This script takes a list of utterance-ids or any file whose first field -# of each line is an utterance-id, and filters an scp -# file (or any file whose "n-th" field is an utterance id), printing -# out only those lines whose "n-th" field is in id_list. The index of -# the "n-th" field is 1, by default, but can be changed by using -# the -f switch - -$exclude = 0; -$field = 1; -$shifted = 0; - -do { - $shifted=0; - if ($ARGV[0] eq "--exclude") { - $exclude = 1; - shift @ARGV; - $shifted=1; - } - if ($ARGV[0] eq "-f") { - $field = $ARGV[1]; - shift @ARGV; shift @ARGV; - $shifted=1 - } -} while ($shifted); - -if(@ARGV < 1 || @ARGV > 2) { - die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . - "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . - "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . - "only the lines that were *not* in id_list.\n" . - "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . - "If your older scripts (written before Oct 2014) stopped working and you used the\n" . - "-f option, add 1 to the argument.\n" . - "See also: utils/filter_scp.pl .\n"; -} - - -$idlist = shift @ARGV; -open(F, "<$idlist") || die "Could not open id-list file $idlist"; -while() { - @A = split; - @A>=1 || die "Invalid id-list file line $_"; - $seen{$A[0]} = 1; -} - -if ($field == 1) { # Treat this as special case, since it is common. - while(<>) { - $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; - # $1 is what we filter on. - if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { - print $_; - } - } -} else { - while(<>) { - @A = split; - @A > 0 || die "Invalid scp file line $_"; - @A >= $field || die "Invalid scp file line $_"; - if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { - print $_; - } - } -} - -# tests: -# the following should print "foo 1" -# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) -# the following should print "bar 2". -# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh deleted file mode 100755 index ca0972ca8..000000000 --- a/egs/swbd/ASR/utils/fix_data_dir.sh +++ /dev/null @@ -1,197 +0,0 @@ -#!/bin/bash - -# This script makes sure that only the segments present in -# all of "feats.scp", "wav.scp" [if present], segments [if present] -# text, and utt2spk are present in any of them. -# It puts the original contents of data-dir into -# data-dir/.backup - -cmd="$@" - -utt_extra_files= -spk_extra_files= - -. utils/parse_options.sh - -if [ $# != 1 ]; then - echo "Usage: utils/data/fix_data_dir.sh " - echo "e.g.: utils/data/fix_data_dir.sh data/train" - echo "This script helps ensure that the various files in a data directory" - echo "are correctly sorted and filtered, for example removing utterances" - echo "that have no features (if feats.scp is present)" - exit 1 -fi - -data=$1 - -if [ -f $data/images.scp ]; then - image/fix_data_dir.sh $cmd - exit $? -fi - -mkdir -p $data/.backup - -[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; - -[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; - -set -e -o pipefail -u - -tmpdir=$(mktemp -d /tmp/kaldi.XXXX); -trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM - -export LC_ALL=C - -function check_sorted { - file=$1 - sort -k1,1 -u <$file >$file.tmp - if ! 
cmp -s $file $file.tmp; then - echo "$0: file $1 is not in sorted order or not unique, sorting it" - mv $file.tmp $file - else - rm $file.tmp - fi -} - -for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ - reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do - if [ -f $data/$x ]; then - cp $data/$x $data/.backup/$x - check_sorted $data/$x - fi -done - - -function filter_file { - filter=$1 - file_to_filter=$2 - cp $file_to_filter ${file_to_filter}.tmp - utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter - if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then - length1=$(cat ${file_to_filter}.tmp | wc -l) - length2=$(cat ${file_to_filter} | wc -l) - if [ $length1 -ne $length2 ]; then - echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." - fi - fi - rm $file_to_filter.tmp -} - -function filter_recordings { - # We call this once before the stage when we filter on utterance-id, and once - # after. - - if [ -f $data/segments ]; then - # We have a segments file -> we need to filter this and the file wav.scp, and - # reco2file_and_utt, if it exists, to make sure they have the same list of - # recording-ids. - - if [ ! -f $data/wav.scp ]; then - echo "$0: $data/segments exists but not $data/wav.scp" - exit 1; - fi - awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings - n1=$(cat $tmpdir/recordings | wc -l) - [ ! -s $tmpdir/recordings ] && \ - echo "Empty list of recordings (bad file $data/segments)?" && exit 1; - utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp - mv $tmpdir/recordings.tmp $tmpdir/recordings - - - cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments - filter_file $tmpdir/recordings $data/segments - cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments - rm $data/segments.tmp - - filter_file $tmpdir/recordings $data/wav.scp - [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel - [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur - true - fi -} - -function filter_speakers { - # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... - utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt - - cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers - for s in cmvn.scp spk2gender; do - f=$data/$s - if [ -f $f ]; then - filter_file $f $tmpdir/speakers - fi - done - - filter_file $tmpdir/speakers $data/spk2utt - utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - - for s in cmvn.scp spk2gender $spk_extra_files; do - f=$data/$s - if [ -f $f ]; then - filter_file $tmpdir/speakers $f - fi - done -} - -function filter_utts { - cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts - - ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ - echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; - - ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ - echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ - echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; - - ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ - echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; - - if [ -f $data/utt2uniq ]; then - ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ - echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; - fi - - maybe_wav= - maybe_reco2dur= - [ ! 
-f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. - [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts - for x in feats.scp text segments utt2lang $maybe_wav; do - if [ -f $data/$x ]; then - utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp - mv $tmpdir/utts.tmp $tmpdir/utts - fi - done - [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ - rm $tmpdir/utts && exit 1; - - - if [ -f $data/utt2spk ]; then - new_nutts=$(cat $tmpdir/utts | wc -l) - old_nutts=$(cat $data/utt2spk | wc -l) - if [ $new_nutts -ne $old_nutts ]; then - echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" - else - echo "fix_data_dir.sh: kept all $old_nutts utterances." - fi - fi - - for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do - if [ -f $data/$x ]; then - cp $data/$x $data/.backup/$x - if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then - utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x - fi - fi - done - -} - -filter_recordings -filter_speakers -filter_utts -filter_speakers -filter_recordings - -utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt - -echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh deleted file mode 100755 index 34476fdb3..000000000 --- a/egs/swbd/ASR/utils/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### No we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl deleted file mode 100755 index 23992f25d..000000000 --- a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -while(<>){ - @A = split(" ", $_); - @A > 1 || die "Invalid line in spk2utt file: $_"; - $s = shift @A; - foreach $u ( @A ) { - print "$u $s\n"; - } -} - - diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl deleted file mode 100755 index 6e0e438ca..000000000 --- a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. 
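spk2utt_to_utt2spk.pl above and utt2spk_to_spk2utt.pl below convert between the two speaker-map formats ('utt spk' per line versus 'spk utt1 utt2 ...'). The same round trip as a rough Python sketch, preserving first-seen speaker order like the Perl scripts do; the utterance ids are examples:

```
from collections import OrderedDict


def utt2spk_to_spk2utt(lines):
    """'utt spk' lines -> 'spk utt1 utt2 ...' lines (first-seen order)."""
    spk2utt = OrderedDict()
    for line in lines:
        utt, spk = line.split()
        spk2utt.setdefault(spk, []).append(utt)
    return [f"{spk} {' '.join(utts)}" for spk, utts in spk2utt.items()]


def spk2utt_to_utt2spk(lines):
    """'spk utt1 utt2 ...' lines -> 'utt spk' lines."""
    out = []
    for line in lines:
        spk, *utts = line.split()
        out.extend(f"{utt} {spk}" for utt in utts)
    return out


utt2spk = ["sw02001-A_000098-001156 sw02001-A", "sw02001-A_001156-002000 sw02001-A"]
print(utt2spk_to_spk2utt(utt2spk))
print(spk2utt_to_utt2spk(utt2spk_to_spk2utt(utt2spk)) == utt2spk)  # True
```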
-# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - -# converts an utt2spk file to a spk2utt file. -# Takes input from the stdin or from a file argument; -# output goes to the standard out. - -if ( @ARGV > 1 ) { - die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; -} - -while(<>){ - @A = split(" ", $_); - @A == 2 || die "Invalid line in utt2spk file: $_"; - ($u,$s) = @A; - if(!$seen_spk{$s}) { - $seen_spk{$s} = 1; - push @spklist, $s; - } - push (@{$spk_hash{$s}}, "$u"); -} -foreach $s (@spklist) { - $l = join(' ',@{$spk_hash{$s}}); - print "$s $l\n"; -} diff --git a/egs/tal_csasr/ASR/README.md b/egs/tal_csasr/ASR/README.md deleted file mode 100644 index a705a2f44..000000000 --- a/egs/tal_csasr/ASR/README.md +++ /dev/null @@ -1,19 +0,0 @@ - -# Introduction - -This recipe includes some different ASR models trained with TAL_CSASR. - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/tal_csasr/ASR/RESULTS.md b/egs/tal_csasr/ASR/RESULTS.md deleted file mode 100644 index e696279bd..000000000 --- a/egs/tal_csasr/ASR/RESULTS.md +++ /dev/null @@ -1,133 +0,0 @@ -## Results - -#### Pruned transducer stateless 7 (zipformer) - -See - -[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe) - -**Note**: The modeling units are byte level BPEs - -The best results I have gotten are: - -Vocab size | greedy (dev & test) | modified beam search (dev & test) | | --- | -- | -- | -- -500 | 6.88 & 6.98 | 6.87 & 6.94 | --epoch 35 --avg 26 - -The training command: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --start-epoch 1 \ - --num-epochs 35 \ - --use-fp16 1 \ - --max-duration 800 \ - --bbpe-model data/lang_bbpe_500/bbpe.model \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --master-port 12535 -``` - -The decoding command: - -``` - ./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 35 \ - --avg 26 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-sym-per-frame 1 \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --max-duration 2000 \ - --decoding-method greedy_search # modified_beam_search -``` - -The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_tal_csasr_pruned_transducer_stateless7_bbpe - - -### TAL_CSASR Mix Chars and BPEs training results (Pruned Transducer Stateless5) - -#### 2022-06-22 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/428. 
- -The WERs are - -|decoding-method | epoch(iter) | avg | dev | test | -|--|--|--|--|--| -|greedy_search | 30 | 24 | 7.49 | 7.58| -|modified_beam_search | 30 | 24 | 7.33 | 7.38| -|fast_beam_search | 30 | 24 | 7.32 | 7.42| -|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39| -|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22| -|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27| -|greedy_search | 348000 | 30 | 7.46 | 7.54| -|modified_beam_search | 348000 | 30 | 7.24 | 7.36| -|fast_beam_search | 348000 | 30 | 7.25 | 7.39 | - -The results (CER(%) and WER(%)) for Chinese CER and English WER respectivly (zh: Chinese, en: English): -|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en | -|--|--|--|--|--|--|--|--|--| -|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| -|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | -|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77| - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" - -./pruned_transducer_stateless5/train.py \ - --world-size 6 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char \ - --max-duration 90 -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars - -The decoding command is: -``` -epoch=30 -avg=24 -use_average_model=True - -## greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 \ - --use-averaged-model $use_average_model - -## modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 800 \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - --use-averaged-model $use_average_model - -## fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - --use-averaged-model $use_average_model -``` - -A pre-trained model and decoding logs can be found at diff --git a/egs/tal_csasr/ASR/local/__init__.py b/egs/tal_csasr/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tal_csasr/ASR/local/compute_fbank_musan.py b/egs/tal_csasr/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/tal_csasr/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py deleted file mode 100755 index 602e50d29..000000000 --- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the tal_csasr dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_tal_csasr(num_mel_bins: int = 80): - src_dir = Path("data/manifests/tal_csasr") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "train_set", - "dev_set", - "test_set", - ) - prefix = "tal_csasr" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - compute_fbank_tal_csasr(num_mel_bins=args.num_mel_bins) diff --git a/egs/tal_csasr/ASR/local/display_manifest_statistics.py b/egs/tal_csasr/ASR/local/display_manifest_statistics.py deleted file mode 100644 index 7521bb55b..000000000 --- a/egs/tal_csasr/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. 
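The docstring above suggests using the displayed duration statistics to choose minimum/maximum duration thresholds. A hedged sketch of such a filter with lhotse; the manifest path matches the one used in main() below, while the 1 s and 20 s thresholds are only example values read off the statistics, not values prescribed by the recipe:

```
from lhotse import load_manifest

cuts = load_manifest("./data/fbank/tal_csasr_cuts_train_set.jsonl.gz")


def remove_short_and_long_utt(c) -> bool:
    # Example thresholds; adjust them after looking at the printed
    # duration statistics for your own data.
    return 1.0 <= c.duration <= 20.0


kept = cuts.filter(remove_short_and_long_utt)
kept.describe()  # re-print the duration statistics for the filtered cuts
```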
-""" - - -from lhotse import load_manifest - - -def main(): - paths = [ - "./data/fbank/tal_csasr_cuts_train_set.jsonl.gz", - "./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz", - "./data/fbank/tal_csasr_cuts_test_set.jsonl.gz", - ] - - for path in paths: - print(f"Displaying the statistics for {path}") - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Displaying the statistics for ./data/fbank/tal_csasr_cuts_train_set.jsonl.gz -Cuts count: 1050000 -Total duration (hours): 1679.0 -Speech duration (hours): 1679.0 (100.0%) -*** -Duration statistics (seconds): -mean 5.8 -std 4.1 -min 0.3 -25% 2.8 -50% 4.4 -75% 7.3 -99% 18.0 -99.5% 18.8 -99.9% 20.8 -max 36.5 -Displaying the statistics for ./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz -Cuts count: 5000 -Total duration (hours): 8.0 -Speech duration (hours): 8.0 (100.0%) -*** -Duration statistics (seconds): -mean 5.8 -std 4.0 -min 0.5 -25% 2.8 -50% 4.5 -75% 7.4 -99% 17.0 -99.5% 17.7 -99.9% 19.5 -max 21.5 -Displaying the statistics for ./data/fbank/tal_csasr_cuts_test_set.jsonl.gz -Cuts count: 15000 -Total duration (hours): 23.6 -Speech duration (hours): 23.6 (100.0%) -*** -Duration statistics (seconds): -mean 5.7 -std 4.0 -min 0.5 -25% 2.8 -50% 4.4 -75% 7.2 -99% 17.2 -99.5% 17.9 -99.9% 19.6 -max 32.3 -""" diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py deleted file mode 100755 index 499937462..000000000 --- a/egs/tal_csasr/ASR/local/prepare_char.py +++ /dev/null @@ -1,264 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/text_with_bpe, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. 
- """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. - Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon( - token_sym_table: Dict[str, int], - words: List[str], - bpe_model=None, -) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - sp = "" - if bpe_model is not None: - sp = spm.SentencePieceProcessor() - sp.load(str(bpe_model)) - - lexicon = [] - zhPattern = re.compile(r"([\u4e00-\u9fa5])") - for word in words: - match = zhPattern.search(word) - tokens = [] - if match: - tokens = list(word.strip(" \t")) - else: - tokens = sp.encode_as_pieces(word.strip(" \t")) - - if contain_oov(token_sym_table, tokens): - continue - lexicon.append((word, tokens)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. 
- """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([\t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - chars = line.split(" ") - for char in chars: - if char not in tokens: - tokens[char] = len(tokens) - - return tokens - - -def main(): - lang_dir = Path("data/lang_char") - text_file = lang_dir / "text_with_bpe" - bpe_model = lang_dir / "bpe.model" - words_file = lang_dir / "words.txt" - - word_sym_table = k2.SymbolTable.from_file(words_file) - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words, bpe_model=bpe_model) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py deleted file mode 100755 index c8cf9b881..000000000 --- a/egs/tal_csasr/ASR/local/prepare_lang.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. 
-""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. 
- Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") - return parser.parse_args() - - -def main(): - out_dir = Path(get_args().lang_dir) - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/prepare_words.py b/egs/tal_csasr/ASR/local/prepare_words.py deleted file mode 100755 index 41ab3b2cb..000000000 --- a/egs/tal_csasr/ASR/local/prepare_words.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
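The L.pt and L_disambig.pt files written by the prepare_lang.py script above can be loaded back with k2, as its docstring notes. A minimal sketch, assuming the script's default data/lang_phone layout and that it has already been run:

import k2
import torch

lang_dir = "data/lang_phone"  # layout assumed by prepare_lang.py above
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L.pt"))
# Attach the symbol tables so labels print as tokens/words instead of integer ids.
L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
print(L.num_arcs)  # number of arcs in the lexicon FST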
- - -""" -This script takes as input words.txt without ids: - - words_no_ids.txt -and generates the new words.txt with related ids. - - words.txt -""" - - -import argparse -import logging - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare words.txt", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input", - default="data/lang_char/words_no_ids.txt", - type=str, - help="the words file without ids for WenetSpeech", - ) - parser.add_argument( - "--output", - default="data/lang_char/words.txt", - type=str, - help="the words file with ids for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input - output_file = args.output - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - add_words = [" 0", "!SIL 1", " 2", " 3"] - new_lines.extend(add_words) - - logging.info("Starting reading the input file") - for i in tqdm(range(len(lines))): - x = lines[i] - idx = 4 + i - new_line = str(x.strip("\n")) + " " + str(idx) - new_lines.append(new_line) - - logging.info("Starting writing the words.txt") - f_out = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py deleted file mode 100755 index 74e025ad7..000000000 --- a/egs/tal_csasr/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/text2segments.py b/egs/tal_csasr/ASR/local/text2segments.py deleted file mode 100644 index 3df727c67..000000000 --- a/egs/tal_csasr/ASR/local/text2segments.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
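To make the behaviour exercised above concrete, here is what add_disambig_symbols() is expected to produce on a reduced version of the toy lexicon from test_prepare_lang.py (a hand-worked sketch, not a transcript of an actual run):

# Reduced toy lexicon from the test above.
lexicon = [
    ("foo",   ["f", "o", "o"]),
    ("food",  ["f", "o", "o", "d"]),
    ("food2", ["f", "o", "o", "d"]),
    ("fo",    ["f", "o"]),
]
# "food" and "food2" share the same token sequence, so they receive #1 and #2;
# "foo" and "fo" are prefixes of longer sequences, so each receives #1.
expected_disambig = [
    ("foo",   ["f", "o", "o", "#1"]),
    ("food",  ["f", "o", "o", "d", "#1"]),
    ("food2", ["f", "o", "o", "d", "#2"]),
    ("fo",    ["f", "o", "#1"]),
]
# For this subset, the returned max_disambig would be 2.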
- - -""" -This script takes as input "text", which refers to the transcript file for -WenetSpeech: - - text -and generates the output file text_word_segmentation which is implemented -with word segmenting: - - text_words_segmentation -""" - - -import argparse - -import jieba -from tqdm import tqdm - -jieba.enable_paddle() - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Chinese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/text", - type=str, - help="the input text file for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with words segmenting for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - for i in tqdm(range(len(lines))): - x = lines[i].rstrip() - seg_list = jieba.cut(x, use_paddle=True) - new_line = " ".join(seg_list) - new_lines.append(new_line) - - f_new = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_new.write(line) - f_new.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py deleted file mode 100755 index 85047c367..000000000 --- a/egs/tal_csasr/ASR/local/text2token.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import codecs -import re -import sys -from typing import List - -from pypinyin import lazy_pinyin, pinyin - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symobles, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "pinyin", "lazy_pinyin"], - help="""Transcript type. char/pinyin/lazy_pinyin""", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the chinese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "pinyin" and "lazy_pinyin". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "lazy_pinyin": - text = lazy_pinyin(chars_list) - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - else: # token_type = "pinyin" - text = pinyin(chars_list) - sub_ids = [ - token_table[txt[0]] if txt[0] in token_table else oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "pinyin": - a = pinyin(list(str(a))) - a = [one[0] for one in a] - - if args.trans_type == "lazy_pinyin": - a = lazy_pinyin(list(str(a))) - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = "".join(a_flat) - print(a_chars) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/text_normalize.py b/egs/tal_csasr/ASR/local/text_normalize.py deleted file mode 100755 index e97b3a5a3..000000000 --- a/egs/tal_csasr/ASR/local/text_normalize.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This script takes as input "text_full", which includes all transcript files -for a dataset: - - text_full -and generates the output file text_normalize which is implemented -to normalize text: - - text -""" - - -import argparse - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Normalizing for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input", - default="data/lang_char/text_full", - type=str, - help="the input text file", - ) - parser.add_argument( - "--output", - default="data/lang_char/text", - type=str, - help="the text implemented with normalizer", - ) - - return parser - - -def text_normalize(str_line: str): - line = str_line.strip().rstrip("\n") - line = line.replace("", "") - line = line.replace("<%>", "") - line = line.replace("<->", "") - line = line.replace("<$>", "") - line = line.replace("<#>", "") - line = line.replace("<_>", "") - line = line.replace("", "") - line = line.replace("`", "") - line = line.replace("'", "") - line = line.replace("&", "") - line = line.replace(",", "") - line = line.replace("A", "A") - line = line.replace("C", "C") - line = line.replace("D", "D") - line = line.replace("E", "E") - line = line.replace("G", "G") - line = line.replace("H", "H") - line = line.replace("I", "I") - line = line.replace("N", "N") - line = line.replace("U", "U") - line = line.replace("W", "W") - line = line.replace("Y", "Y") - line = line.replace("a", "A") - line = line.replace("b", "B") - line = line.replace("c", "C") - line = line.replace("k", "K") - line = line.replace("t", "T") - line = line.replace(",", "") - line = line.replace("丶", "") - line = line.replace("。", "") - line = line.replace("、", "") - line = line.replace("?", "") - line = line.replace("·", "") - line = line.replace("*", "") - line = line.replace("!", "") - line = line.replace("$", "") - line = line.replace("+", "") - line = line.replace("-", "") - line = line.replace("\\", "") - line = line.replace("?", "") - line = line.replace("¥", "") - line = line.replace("%", "") - line = line.replace(".", "") - line = line.replace("<", "") - line = line.replace("&", "") - line = line.replace("~", "") - line = line.replace("=", "") - line = line.replace(":", "") - line = line.replace("!", "") - line = line.replace("/", "") - line = line.replace("‘", "") - line = line.replace("’", "") - line = line.replace("“", "") - line = line.replace("”", "") - line = line.replace("[", "") - line = line.replace("]", "") - line = line.replace("@", "") - line = line.replace("#", "") - line = line.replace(":", "") - line = line.replace(";", "") - line = line.replace("…", "") - line = line.replace("《", "") - line = line.replace("》", "") - line = line.upper() - - return line - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input - output_file = args.output - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - for i in tqdm(range(len(lines))): - new_line = text_normalize(lines[i]) - new_lines.append(new_line) - - f_new = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_new.write(line) - f_new.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py b/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py deleted file mode 100644 index d7fd838f2..000000000 --- a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python -# -*- 
coding: utf-8 -*- -# -# Copyright 2021 Mobvoi Inc. (authors: Binbin Zhang) -# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes as input text (it includes Chinese and English): - - text -and generates the text_with_bpe. - - text_with_bpe -""" - - -import argparse -import logging - -import sentencepiece as spm -from tqdm import tqdm - -from icefall.utils import tokenize_by_bpe_model - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare text_with_bpe", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input", - default="data/lang_char/text", - type=str, - help="the text includes Chinese and English words", - ) - parser.add_argument( - "--output", - default="data/lang_char/text_with_bpe", - type=str, - help="the text_with_bpe tokenized by bpe model", - ) - parser.add_argument( - "--bpe-model", - default="data/lang_char/bpe.model", - type=str, - help="the bpe model for processing the English parts", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input - output_file = args.output - bpe_model = args.bpe_model - - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - - logging.info("Starting reading the text") - new_lines = [] - for i in tqdm(range(len(lines))): - x = lines[i] - txt_tokens = tokenize_by_bpe_model(sp, x) - new_line = txt_tokens.replace("/", " ") - new_lines.append(new_line) - - logging.info("Starting writing the text_with_bpe") - f_out = open(output_file, "w", encoding="utf-8") - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/local/train_bbpe_model.py b/egs/tal_csasr/ASR/local/train_bbpe_model.py deleted file mode 120000 index 7fb4a9f9d..000000000 --- a/egs/tal_csasr/ASR/local/train_bbpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/train_bbpe_model.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh deleted file mode 100755 index 2de4ac8f5..000000000 --- a/egs/tal_csasr/ASR/prepare.sh +++ /dev/null @@ -1,198 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/TALCS_corpus -# You can find three directories:train_set, dev_set, and test_set. 
-# You can get it from https://ai.100tal.com/dataset -# - dev_set -# - test_set -# - train_set -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bbpe_xxx, -# data/lang_bbpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 2000 - 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - # Before you run this script, you must get the TAL_CSASR dataset - # from https://ai.100tal.com/dataset - if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then - mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr - fi - - # If you have pre-downloaded it to /path/to/TALCS_corpus, - # you can create a symlink - # - # ln -sfv /path/to/TALCS_corpus $dl_dir/tal_csasr - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare tal_csasr manifest" - # We assume that you have downloaded the TALCS_corpus - # to $dl_dir/tal_csasr - if [ ! -f data/manifests/tal_csasr/.manifests.done ]; then - mkdir -p data/manifests/tal_csasr - lhotse prepare tal-csasr $dl_dir/tal_csasr data/manifests/tal_csasr - touch data/manifests/tal_csasr/.manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for tal_csasr" - if [ ! -f data/fbank/.tal_csasr.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_tal_csasr.py - touch data/fbank/.tal_csasr.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare char based lang" - lang_char_dir=data/lang_char - mkdir -p $lang_char_dir - - # Download BPE models trained with LibriSpeech - # Here we use the BPE model with 5000 units trained with Librispeech. - # You can also use other BPE models if available. - if [ ! -f $lang_char_dir/bpe.model ]; then - wget -O $lang_char_dir/bpe.model \ - https://huggingface.co/luomingshuang/bpe_models_trained_with_Librispeech/resolve/main/lang_bpe_500/bpe.model - fi - - # we extract text from manifests rather than the label.txt in corpus, because - # the texts in manifests have been normalized in lhotse. - if [ ! 
-f $lang_char_dir/text ]; then - gunzip -c data/manifests/tal_csasr/tal_csasr_supervisions_train_set.jsonl.gz \ - | grep -o 'text":\s[^,]*' | sed 's/text": "//g;s/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_train - - gunzip -c data/manifests/tal_csasr/tal_csasr_supervisions_dev_set.jsonl.gz \ - | grep -o 'text":\s[^,]*' | sed 's/text": "//g;s/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_dev - - gunzip -c data/manifests/tal_csasr/tal_csasr_supervisions_test_set.jsonl.gz \ - | grep -o 'text":\s[^,]*' | sed 's/text": "//g;s/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text_test - - for r in text_train text_dev text_test ; do - cat $lang_char_dir/$r >> $lang_char_dir/text - done - fi - - # Prepare words.txt - # We assume you have installed jieba, if not, please install - # it using: pip install jieba - if [ ! -f $lang_char_dir/words.txt ]; then - python -m jieba $lang_char_dir/text | sed 's/\///g;s/\s\+/ /g' > $lang_char_dir/text.seg - - (echo ' 0'; echo '!SIL 1'; echo ' 2'; echo ' 3';) \ - > $lang_char_dir/words.txt - - cat $lang_char_dir/text.seg | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ - | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt - - num_lines=$(< $lang_char_dir/words.txt wc -l) - (echo "#0 $num_lines"; echo " $(($num_lines + 1))"; echo " $(($num_lines + 2))";) \ - >> $lang_char_dir/words.txt - fi - - # Tokenize text with BPE model - python ./local/tokenize_with_bpe_model.py \ - --input $lang_char_dir/text \ - --output $lang_char_dir/text_with_bpe \ - --bpe-model $lang_char_dir/bpe.model - - if [ ! -f $lang_char_dir/L_disambig.pt ]; then - python local/prepare_char.py - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 7: Prepare Byte BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bbpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp $lang_char_dir/words.txt $lang_dir - cp $lang_char_dir/text $lang_dir - - if [ ! -f $lang_dir/bbpe.model ]; then - ./local/train_bbpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/text - fi - done -fi diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/__init__.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 100644 index 6f0833db6..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1,425 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
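Stage 6 above delegates byte-level BPE training to ./local/train_bbpe_model.py (a symlink into the aishell recipe). The sketch below shows only the generic sentencepiece API that such a step builds on; it is not the byte-level variant the symlinked script implements, and the paths and vocabulary size are illustrative.

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/lang_bbpe_500/text",     # one transcript per line (illustrative path)
    model_prefix="data/lang_bbpe_500/bpe",
    vocab_size=500,
    model_type="bpe",
)
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bpe.model")
print(sp.encode_as_pieces("hello world"))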
- - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class TAL_CSASRAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - - group.add_argument( - "--num-buckets", - type=int, - default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - num_cuts_for_bins_estimate=20000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - rank=0, - world_size=1, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - rank=0, - world_size=1, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tal_csasr_cuts_train_set.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tal_csasr_cuts_dev_set.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tal_csasr_cuts_test_set.jsonl.gz" - ) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/beam_search.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index ed78bd4bb..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 120000 index b2af4e1df..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index 
3485d4005..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,740 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TAL_CSASRAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from local.text_normalize import text_normalize -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - sp: spm.SentencePieceProcessor = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
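In this recipe each value of the returned dict is actually a 3-tuple (hyps, zh_hyps, en_hyps), where every element is a list with one entry per utterance, and the key encodes the decoding setting (e.g. "greedy_search", or the fast_beam_search key built from beam/max_contexts/max_states). A minimal, self-contained sketch of the Chinese/English splitting that produces those three lists for one utterance (the example string is made up):

    import re

    # Same expressions as in this file: split at every CJK character, then
    # collect pure-Chinese and pure-English tokens separately.
    pattern = re.compile(r"([\u4e00-\u9fff])")
    en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters ('|' is literal here)
    zh_char = "[\u4e00-\u9fa5]+"  # Chinese characters

    hyp = "我们 TODAY 学习 MATH"  # made-up decoded hypothesis
    chars_new, zh_text, en_text = [], [], []
    for char in pattern.split(hyp.upper()):
        if char != "":
            tokens = char.strip().split(" ")
            chars_new.extend(tokens)
            for token in tokens:
                zh_text.extend(re.findall(zh_char, token))
                en_text.extend(re.findall(en_letter, token))

    print(chars_new)  # ['我', '们', 'TODAY', '学', '习', 'MATH']
    print(zh_text)    # ['我', '们', '学', '习']
    print(en_text)    # ['TODAY', 'MATH']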
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - zh_hyps = [] - en_hyps = [] - pattern = re.compile(r"([\u4e00-\u9fff])") - en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters - zh_char = "[\u4e00-\u9fa5]+" # Chinese chars - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - zh_text = [] - en_text = [] - for char in chars: - if char != "": - tokens = char.strip().split(" ") - chars_new.extend(tokens) - for token in tokens: - zh_text.extend(re.findall(zh_char, token)) - en_text.extend(re.findall(en_letter, token)) - hyps.append(chars_new) - zh_hyps.append(zh_text) - en_hyps.append(en_text) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - zh_text = [] - en_text = [] - for char in chars: - if char != "": - tokens = char.strip().split(" ") - chars_new.extend(tokens) - for token in tokens: - zh_text.extend(re.findall(zh_char, token)) - en_text.extend(re.findall(en_letter, token)) - hyps.append(chars_new) - zh_hyps.append(zh_text) - en_hyps.append(en_text) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - zh_text = [] - en_text = [] - for char in chars: - if char != "": - tokens = char.strip().split(" ") - chars_new.extend(tokens) - for token in tokens: - zh_text.extend(re.findall(zh_char, token)) - en_text.extend(re.findall(en_letter, token)) - hyps.append(chars_new) - zh_hyps.append(zh_text) - en_hyps.append(en_text) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - zh_text = [] - en_text = [] - for char in chars: - if char != "": - tokens = char.strip().split(" ") - 
chars_new.extend(tokens) - for token in tokens: - zh_text.extend(re.findall(zh_char, token)) - en_text.extend(re.findall(en_letter, token)) - hyps.append(chars_new) - zh_hyps.append(zh_text) - en_hyps.append(en_text) - if params.decoding_method == "greedy_search": - return {"greedy_search": (hyps, zh_hyps, en_hyps)} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): (hyps, zh_hyps, en_hyps) - } - else: - return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, - sp: spm.SentencePieceProcessor = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - zh_results = defaultdict(list) - en_results = defaultdict(list) - pattern = re.compile(r"([\u4e00-\u9fff])") - en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters - zh_char = "[\u4e00-\u9fa5]+" # Chinese chars - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - zh_texts = [] - en_texts = [] - for i in range(len(texts)): - text = texts[i] - chars = pattern.split(text.upper()) - chars_new = [] - zh_text = [] - en_text = [] - for char in chars: - if char != "": - tokens = char.strip().split(" ") - chars_new.extend(tokens) - for token in tokens: - zh_text.extend(re.findall(zh_char, token)) - en_text.extend(re.findall(en_letter, token)) - zh_texts.append(zh_text) - en_texts.append(en_text) - texts[i] = chars_new - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - sp=sp, - ) - - for name, hyps_texts in hyps_dict.items(): - this_batch = [] - this_batch_zh = [] - this_batch_en = [] - # print(hyps_texts) - hyps, zh_hyps, en_hyps = hyps_texts - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - for cut_id, hyp_words, ref_text in zip(cut_ids, zh_hyps, zh_texts): - this_batch_zh.append((cut_id, ref_text, hyp_words)) - - for cut_id, hyp_words, ref_text in zip(cut_ids, en_hyps, en_texts): - this_batch_en.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - zh_results[name + "_zh"].extend(this_batch_zh) - en_results[name + "_en"].extend(this_batch_en) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return 
results, zh_results, en_results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TAL_CSASRAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - bpe_model = params.lang_dir + "/bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - 
model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - def text_normalize_for_cut(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = text.strip("\n").strip("\t") - c.supervisions[0].text = text_normalize(text) - return c - - # we need cut ids to display recognition results. 
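    # `return_cuts=True` makes K2SpeechRecognitionDataset keep the original Cut
    # objects in batch["supervisions"]["cut"], which is how decode_dataset()
    # above recovers a cut.id for every utterance when storing transcripts.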
- args.return_cuts = True - tal_csasr = TAL_CSASRAsrDataModule(args) - - dev_cuts = tal_csasr.valid_cuts() - dev_cuts = dev_cuts.map(text_normalize_for_cut) - dev_dl = tal_csasr.valid_dataloaders(dev_cuts) - - test_cuts = tal_csasr.test_cuts() - test_cuts = test_cuts.map(text_normalize_for_cut) - test_dl = tal_csasr.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict, zh_results_dict, en_results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - sp=sp, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - save_results( - params=params, - test_set_name=test_set + "-zh", - results_dict=zh_results_dict, - ) - save_results( - params=params, - test_set_name=test_set + "-en", - results_dict=en_results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decoder.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 8a5e07bd5..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index 2fc10439b..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 0f6190a41..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
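For intuition, "model averaging" here is essentially an element-wise mean of the parameter tensors stored in several checkpoints. A minimal sketch of that idea, assuming checkpoints that keep their weights under a "model" key as the files in this recipe do; the script itself relies on icefall.checkpoint.average_checkpoints and average_checkpoints_with_averaged_model, which handle more details:

    from typing import Dict, List

    import torch

    def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
        """Illustrative only: mean of the "model" entry of several checkpoints."""
        avg = None
        for f in filenames:
            state = torch.load(f, map_location="cpu")["model"]
            if avg is None:
                # Cast to float so integer buffers can be averaged too (sketch only).
                avg = {k: v.clone().float() for k, v in state.items()}
            else:
                for k in avg:
                    avg[k] += state[k].float()
        for k in avg:
            avg[k] /= len(filenames)
        return avg

    # e.g. avg = average_state_dicts([f"exp/epoch-{i}.pt" for i in range(28, 31)])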
-""" -Usage: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --tokens ./data/lang_char/tokens.txt \ - --epoch 30 \ - --avg 24 \ - --use-averaged-model True - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/tal_csasr/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 30 \ - --avg 24 \ - --max-duration 800 \ - --decoding-method greedy_search \ - --lang-dir ./data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - # is defined in local/train_bpe_model.py - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
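As a small, self-contained illustration of the trick applied just below (ignore the non-scriptable forward(), script the rest, and call only exported methods), with a toy module standing in for the transducer:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Stand-in for a method TorchScript cannot compile.
            return self.linear(x)

        @torch.jit.export
        def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    Toy.forward = torch.jit.ignore(Toy.forward)  # same pattern as below
    scripted = torch.jit.script(Toy())
    print(scripted.run_encoder(torch.zeros(2, 4)).shape)  # torch.Size([2, 4])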
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/joiner.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index f31b5fd9b..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/model.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index be059ba7c..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/optim.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 661206562..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100755 index 8a74ee745..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,368 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -import re -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - bpe_model = params.lang_dir + "/bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - lexicon = Lexicon(params.lang_dir) - params.blank_di = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - pattern = re.compile(r"([\u4e00-\u9fff])") - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - for char in chars: - if char != "": - chars_new.extend(char.strip().split(" ")) - 
hyps.append(chars_new) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - for char in chars: - if char != "": - chars_new.extend(char.strip().split(" ")) - hyps.append(chars_new) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - chars = pattern.split(hyp.upper()) - chars_new = [] - for char in chars: - if char != "": - chars_new.extend(char.strip().split(" ")) - hyps.append(chars_new) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp]) - chars = pattern.split(hyp.upper()) - chars_new = [] - for char in chars: - if char != "": - chars_new.extend(char.strip().split(" ")) - hyps.append(chars_new) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index be7b111c6..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/test_model.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/test_model.py deleted file mode 100755 index 9aad32014..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/test_model.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py -""" - -from train import get_params, get_transducer_model - - -def test_model_1(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 24 - params.dim_feedforward = 1536 # 384 * 4 - params.encoder_dim = 384 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf -def test_model_M(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 18 - params.dim_feedforward = 1024 - params.encoder_dim = 256 - params.nhead = 4 - params.decoder_dim = 512 - params.joiner_dim = 512 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -def main(): - # test_model_1() - test_model_M() - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index c0aedd725..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1084 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import TAL_CSASRAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from local.text_normalize import text_normalize -from local.tokenize_with_bpe_model import tokenize_by_bpe_model -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 100, - "valid_interval": 2000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 1000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
- pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - # print(batch["supervisions"]) - - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
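The scale applied to `pruned_loss` above follows a three-stage schedule: 0.0 while the model is still warming up, 0.1 for roughly the next `model_warm_step` batches, and 1.0 afterwards, so the pruned loss cannot overwhelm the simple loss before a usable alignment has been learned. Written out as a standalone helper, the inline expression is:

```python
def pruned_loss_scale(warmup: float) -> float:
    """Mirror of the inline expression above: warmup < 1.0 -> 0.0,
    1.0 < warmup < 2.0 -> 0.1, otherwise 1.0."""
    if warmup < 1.0:
        return 0.0
    if 1.0 < warmup < 2.0:
        return 0.1
    return 1.0


# loss = params.simple_loss_scale * simple_loss \
#        + pruned_loss_scale(warmup) * pruned_loss
```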
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - bpe_model = params.lang_dir + "/bpe.model" - import sentencepiece as spm - - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - tal_csasr = TAL_CSASRAsrDataModule(args) - train_cuts = tal_csasr.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 20.0 - - def text_normalize_for_cut(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = text.strip("\n").strip("\t") - text = text_normalize(text) - text = tokenize_by_bpe_model(sp, text) - c.supervisions[0].text = text - return c - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(text_normalize_for_cut) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = tal_csasr.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = tal_csasr.valid_cuts() - valid_cuts = valid_cuts.map(text_normalize_for_cut) - valid_dl = tal_csasr.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - return - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. 
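The duration filter shown earlier keeps utterances between 1 s and 20 s, and its comment points at `../local/display_manifest_statistics.py` for choosing those cutoffs. lhotse can print the relevant summary directly; a small sketch, assuming `train_cuts` is the CutSet returned by the data module:

```python
# Prints cut counts, total duration and duration percentiles -- the kind of
# summary used to pick sensible min/max duration thresholds.
train_cuts.describe()
```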
- with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - TAL_CSASRAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py deleted file mode 120000 index c473a600a..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless5/asr_datamodule.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/beam_search.py deleted file mode 120000 index 4eef3d295..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless5/beam_search.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py deleted file mode 100755 index 885778965..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py +++ /dev/null @@ -1,815 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
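The training script above launches one process per GPU with `mp.spawn` and wraps the model in DDP. A self-contained toy version of that pattern (gloo backend on CPU so the sketch runs anywhere; the real script assigns one CUDA device per rank via `setup_dist` and `DDP(model, device_ids=[rank])`, and the port below is arbitrary):

```python
import os

import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12354")
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(8, 2))
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()  # gradients are all-reduced across ranks here

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```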
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_bbpe/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TAL_CSASRAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import ( - LmScorer, - NgramLm, - byte_encode, - smart_byte_decode, - tokenize_by_CJK_char, -) -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the byte BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bbpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.25, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - subtract_ilme=True, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - ref_texts = [] - for tx in supervisions["text"]: - ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) - - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(ref_texts), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = 
f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - key += f"_ilme_scale_{params.ilme_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = tokenize_by_CJK_char(ref_text).split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
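`decode_dataset` above splits each reference with `tokenize_by_CJK_char`, so every CJK character becomes its own token while Latin words stay intact, which makes the reported error rate effectively a mixed character/word error rate for this code-switching corpus. A toy stand-in that shows the behaviour (the real icefall implementation may differ in details such as which Unicode ranges it covers):

```python
import re


def tokenize_by_cjk_char_toy(text: str) -> str:
    # Surround each CJK character with spaces, then collapse whitespace.
    text = re.sub(r"([\u4e00-\u9fff])", r" \1 ", text)
    return " ".join(text.split())


assert tokenize_by_cjk_char_toy("我们test你好") == "我 们 test 你 好"
```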
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TAL_CSASRAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - params.suffix += f"-ilme-scale-{params.ilme_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", 
model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
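In the non-averaged-model branch above, `average_checkpoints` amounts to an element-wise mean over the `"model"` state dicts of the selected `epoch-*.pt` or `checkpoint-*.pt` files. A naive sketch of that idea (the real helper also handles integer tensors and other bookkeeping, so treat this as an approximation):

```python
import torch


def naive_average_checkpoints(filenames):
    """Element-wise mean of the "model" state dicts stored in the given
    checkpoint files."""
    avg = None
    for name in filenames:
        state = torch.load(name, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg
```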
- args.return_cuts = True - tal_csasr = TAL_CSASRAsrDataModule(args) - - test_cuts = tal_csasr.test_cuts() - dev_cuts = tal_csasr.valid_cuts() - - test_dl = tal_csasr.test_dataloaders(test_cuts) - dev_dl = tal_csasr.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py deleted file mode 120000 index 083f693ef..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless5/encoder_interface.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/export.py deleted file mode 100755 index 862509d3f..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/export.py +++ /dev/null @@ -1,320 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. 
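`pretrained.pt` is an ordinary torch checkpoint whose weights live under the `"model"` key (see the `torch.save({"model": model.state_dict()}, ...)` call later in this script), so it can also be inspected or loaded without icefall; the path below is a placeholder for your own exp-dir:

```python
import torch

ckpt = torch.load(
    "pruned_transducer_stateless7_bbpe/exp/pretrained.pt", map_location="cpu"
)
state_dict = ckpt["model"]
print(f"{len(state_dict)} tensors, e.g. {sorted(state_dict)[:3]}")
# To run inference with it, build the model with the same hyper-parameters
# and call model.load_state_dict(state_dict).
```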
- -To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7_bbpe/decode.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bbpe_500/bbpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/pkufool/icefall_asr_tal_csasr_pruned_transducer_stateless7_bbpe - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/pkufool/icefall_asr_tal_csasr_pruned_transducer_stateless7_bbpe - # You will find the pre-trained model in icefall_asr_tal_csasr_pruned_transducer_stateless7_bbpe/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_bbpe/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
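For readers unfamiliar with the TorchScript flow used around this point, here is a self-contained round trip on a toy module (script, save, reload); the export below applies the same steps to the full transducer after converting its scaled layers and ignoring the non-scriptable `forward()`:

```python
import torch


class TinyEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(80, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


scripted = torch.jit.script(TinyEncoder())     # compile to TorchScript
scripted.save("tiny_cpu_jit.pt")               # same role as cpu_jit.pt
reloaded = torch.jit.load("tiny_cpu_jit.pt")
print(reloaded(torch.randn(1, 10, 80)).shape)  # torch.Size([1, 10, 32])
```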
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py deleted file mode 100755 index 503cdf4ed..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ /dev/null @@ -1,275 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 20 \ - --avg 10 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless7_bbpe/jit_pretrained.py \ - --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall import smart_byte_decode - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. 
- Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - 
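`read_sound_files` above requires 16 kHz input, matching the Fbank options used below; if your recordings use a different rate, resample them first. A small helper that is not part of the recipe, using torchaudio's functional resampler:

```python
import torch
import torchaudio


def load_16k_mono(path: str) -> torch.Tensor:
    """Load a sound file as a 1-D float32 tensor at 16 kHz (channel 0 only)."""
    wave, sr = torchaudio.load(path)
    wave = wave[:1]  # keep the first channel, shape (1, num_samples)
    if sr != 16000:
        wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000)
    return wave[0]
```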
waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = smart_byte_decode(sp.decode(hyp)) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py deleted file mode 100755 index 6e07b5949..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ /dev/null @@ -1,356 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. 
-You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_bbpe/export.py \ - --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ - --bpe-model data/lang_bbpe_500/bbpe.model \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_bbpe/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ - --bpe-model ./data/lang_bbpe_500/bbpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_bbpe/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_bbpe/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import smart_byte_decode -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - 
encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(smart_byte_decode(hyp).split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/test_model.py deleted file mode 120000 index 7ceac5d10..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py deleted file mode 100755 index 2108266ec..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ /dev/null @@ -1,1248 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --max-duration 400 - -# For mix precision training: - -./pruned_transducer_stateless7_bbpe/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_bbpe/exp \ - --max-duration 800 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import TAL_CSASRAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut, CutSet -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import byte_encode, diagnostics, tokenize_by_CJK_char -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension 
in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_bbpe/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bbpe_500/bbpe.model", - help="Path to the Byte BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. 
- - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
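The body of `compute_loss` below interpolates the weights of the simple and pruned losses over the first `warm_step` batches. A small illustration of that schedule, assuming the defaults used in this recipe (`warm_step=2000`, `--simple-loss-scale 0.5`):

```python
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # Mirrors the expressions in compute_loss (illustrative only).
    simple_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_scale, pruned_scale


print(loss_scales(0))     # (1.0, 0.1)   -> rely mostly on the simple loss early on
print(loss_scales(1000))  # (0.75, 0.55)
print(loss_scales(2000))  # (0.5, 1.0)   -> fully warmed up
```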
- """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
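To make the two conventions in that note concrete: losses are summed rather than averaged, and `tot_loss` decays older batches by `1 - 1/reset_interval` so the logged value tracks recent training. A toy sketch with made-up numbers (not the actual `MetricsTracker` class):

```python
reset_interval = 200  # the recipe's default from get_params()

tot_loss = {"loss": 0.0, "frames": 0.0}
for batch_loss_sum, batch_frames in [(1500.0, 6000), (900.0, 3600), (1200.0, 4800)]:
    decay = 1.0 - 1.0 / reset_interval
    tot_loss = {
        "loss": tot_loss["loss"] * decay + batch_loss_sum,
        "frames": tot_loss["frames"] * decay + batch_frames,
    }

# What is logged and compared across epochs is the per-frame value,
# mirroring `tot_loss["loss"] / tot_loss["frames"]` later in train_one_epoch.
print(tot_loss["loss"] / tot_loss["frames"])  # 0.25 for these numbers
```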
- scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bbpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - tal_csasr = TAL_CSASRAsrDataModule(args) - train_cuts = tal_csasr.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 12 seconds - # - # Caution: There is a reason to select 12.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. 
" - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - def tokenize_text_in_cut(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = byte_encode(tokenize_by_CJK_char(text)) - c.supervisions[0].text = text - return c - - logging.info(f"Filtering short and long utterances.") - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - logging.info(f"Tokenizing and encoding texts in train cuts.") - train_cuts = train_cuts.map(tokenize_text_in_cut) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = tal_csasr.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = tal_csasr.valid_cuts() - - logging.info(f"Tokenizing and encoding texts in valid cuts.") - valid_cuts = valid_cuts.map(tokenize_text_in_cut) - - valid_dl = tal_csasr.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. 
- """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - TAL_CSASRAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/shared b/egs/tal_csasr/ASR/shared deleted file mode 120000 index e9461a6d7..000000000 --- a/egs/tal_csasr/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../librispeech/ASR/shared \ No newline at end of file diff --git a/egs/tedlium3/ASR/README.md b/egs/tedlium3/ASR/README.md deleted file mode 100644 index 0740258a7..000000000 --- a/egs/tedlium3/ASR/README.md +++ /dev/null @@ -1,18 +0,0 @@ - -# Introduction - -This recipe includes some different ASR models trained with TedLium3. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. 
- -| | Encoder | Decoder | Comment | -|----------------------------------|-----------|--------------------|-----------------------------| -| `transducer_stateless` | Conformer | Embedding + Conv1d | | -| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md deleted file mode 100644 index bd8a5b43f..000000000 --- a/egs/tedlium3/ASR/RESULTS.md +++ /dev/null @@ -1,341 +0,0 @@ -## Results - -### TedLium3 BPE training results (Zipformer) - -#### 2023-06-15 (Regular transducer) - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/1125. - -Number of model parameters: 65549011, i.e., 65.5 M - -The WERs are - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 6.74 | 6.16 | --epoch 50, --avg 22, --max-duration 500 | -| beam search (beam size 4) | 6.56 | 5.95 | --epoch 50, --avg 22, --max-duration 500 | -| modified beam search (beam size 4) | 6.54 | 6.00 | --epoch 50, --avg 22, --max-duration 500 | -| fast beam search (set as default) | 6.91 | 6.28 | --epoch 50, --avg 22, --max-duration 500 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./zipformer/train.py \ - --use-fp16 true \ - --world-size 4 \ - --num-epochs 50 \ - --start-epoch 0 \ - --exp-dir zipformer/exp \ - --max-duration 1000 -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/AKXbJha0S9aXyfmuvG4h5A/#scalars - -The decoding command is: -``` -epoch=50 -avg=22 - -## greedy search -./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir zipformer/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 500 - -## beam search -./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir zipformer/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 500 \ - --decoding-method beam_search \ - --beam-size 4 - -## modified beam search -./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir zipformer/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 500 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search -./zipformer/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./zipformer/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -A pre-trained model and decoding logs can be found at - -#### 2023-06-26 (Modified transducer) - -``` -./zipformer/train.py \ - --use-fp16 true \ - --world-size 4 \ - --num-epochs 50 \ - --start-epoch 0 \ - --exp-dir zipformer/exp \ - --max-duration 1000 \ - --rnnt-type modified -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/3d4bYmbJTGiWQQaW88CVEQ/#scalars - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 6.32 | 5.83 | --epoch 50, --avg 22, --max-duration 500 | -| modified beam search (beam size 4) | 6.16 | 5.79 | --epoch 50, --avg 22, --max-duration 500 | -| fast beam 
search (set as default) | 6.30 | 5.89 | --epoch 50, --avg 22, --max-duration 500 | - -A pre-trained model and decoding logs can be found at . - -### TedLium3 BPE training results (Conformer-CTC 2) - -#### [conformer_ctc2](./conformer_ctc2) - -See for more details. - -The tensorboard log can be found at - - -You can find a pretrained model and decoding results at: - - -Number of model parameters: 101141699, i.e., 101.14 M - -The WERs are - -| | dev | test | comment | -|--------------------------|------------|-------------|---------------------| -| ctc decoding | 6.45 | 5.96 | --epoch 38 --avg 26 | -| 1best | 5.92 | 5.51 | --epoch 38 --avg 26 | -| whole lattice rescoring | 5.96 | 5.47 | --epoch 38 --avg 26 | -| attention decoder | 5.60 | 5.33 | --epoch 38 --avg 26 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./conformer_ctc2/train.py \ - --world-size 4 \ - --num-epochs 40 \ - --exp-dir conformer_ctc2/exp \ - --max-duration 350 \ - --use-fp16 true -``` - -The decoding command is: -``` -epoch=38 -avg=26 - -## ctc decoding -./conformer_ctc2/decode.py \ - --method ctc-decoding \ - --exp-dir conformer_ctc2/exp \ - --lang-dir data/lang_bpe_500 \ - --result-dir conformer_ctc2/exp \ - --max-duration 500 \ - --epoch $epoch \ - --avg $avg - -## 1best -./conformer_ctc2/decode.py \ - --method 1best \ - --exp-dir conformer_ctc2/exp \ - --lang-dir data/lang_bpe_500 \ - --result-dir conformer_ctc2/exp \ - --max-duration 500 \ - --epoch $epoch \ - --avg $avg - -## whole lattice rescoring -./conformer_ctc2/decode.py \ - --method whole-lattice-rescoring \ - --exp-dir conformer_ctc2/exp \ - --lm-path data/lm/G_4_gram_big.pt \ - --lang-dir data/lang_bpe_500 \ - --result-dir conformer_ctc2/exp \ - --max-duration 500 \ - --epoch $epoch \ - --avg $avg - -## attention decoder -./conformer_ctc2/decode.py \ - --method attention-decoder \ - --exp-dir conformer_ctc2/exp \ - --lang-dir data/lang_bpe_500 \ - --result-dir conformer_ctc2/exp \ - --max-duration 500 \ - --epoch $epoch \ - --avg $avg -``` - -### TedLium3 BPE training results (Pruned Transducer) - -#### 2022-03-21 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/261. 
- -The WERs are - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 | -| beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 | -| modified beam search (beam size 4) | 6.77 | 6.14 | --epoch 29, --avg 13, --max-duration 100 | -| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500| - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless/exp \ - --max-duration 300 -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/VpA8b7SZQ7CEjZs9WZ5HNA/#scalars - -The decoding command is: -``` -epoch=29 -avg=13 - -## greedy search -./pruned_transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 - -## beam search -./pruned_transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -## modified beam search -./pruned_transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir pruned_transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search -./pruned_transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -``` - -A pre-trained model and decoding logs can be found at - -### TedLium3 BPE training results (Transducer) - -#### Conformer encoder + embedding decoder - -##### 2022-03-21 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/233 -And the SpecAugment codes from this PR https://github.com/lhotse-speech/lhotse/pull/604 - -Conformer encoder + non-current decoder. The decoder -contains only an embedding layer and a Conv1d (with kernel size 2). 
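For intuition, such a stateless prediction network can be sketched roughly as follows; this is an illustrative toy module, not the recipe's actual `decoder.py`, whose details differ:

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d(kernel_size=2): the prediction network conditions on
    a fixed two-token context instead of a recurrent hidden state."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids -> (N, U - context_size + 1, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, U)
        out = self.conv(emb).permute(0, 2, 1)
        return torch.relu(out)


decoder = StatelessDecoder(vocab_size=500, embed_dim=512)
tokens = torch.randint(0, 500, (4, 10))  # a batch of 4 partial hypotheses
print(decoder(tokens).shape)             # torch.Size([4, 9, 512])
```

Because the receptive field is only two tokens, decoding only needs to keep the last two symbols of each hypothesis, which is what keeps this family of decoders cheap compared to an RNN prediction network.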
- -The WERs are - -| | dev | test | comment | -|------------------------------------|------------|------------|------------------------------------------| -| greedy search | 7.19 | 6.70 | --epoch 29, --avg 11, --max-duration 100 | -| beam search (beam size 4) | 7.02 | 6.36 | --epoch 29, --avg 11, --max-duration 100 | -| modified beam search (beam size 4) | 6.91 | 6.33 | --epoch 29, --avg 11, --max-duration 100 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ - --max-duration 300 -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/4ks15jYHR4uMyvpW7Nz76Q/#scalars - -The decoding command is: -``` -epoch=29 -avg=11 - -## greedy search -./transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 - -## beam search -./transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -## modified beam search -./transducer_stateless/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir transducer_stateless/exp \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 -``` - -A pre-trained model and decoding logs can be found at diff --git a/egs/tedlium3/ASR/conformer_ctc2/__init__.py b/egs/tedlium3/ASR/conformer_ctc2/__init__.py deleted file mode 100755 index e69de29bb..000000000 diff --git a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py deleted file mode 120000 index 49b2ee483..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/attention.py b/egs/tedlium3/ASR/conformer_ctc2/attention.py deleted file mode 100644 index 178cd7e62..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/attention.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2022 Behavox LLC. (author: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple, Union - -import torch -from scaling import ScaledLinear - - -class MultiheadAttention(torch.nn.Module): - """Allows the model to jointly attend to information - from different representation subspaces. 
This is a modified - version of the original version of multihead attention - (see Attention Is All You Need ) - with replacement of input / output projection layers - with newly introduced ScaleLinear layer - (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py). - - Args: - embed_dim: - total dimension of the model. - num_heads: - number of parallel attention heads. Note that embed_dim will be split - across num_heads, i.e. each head will have dimension (embed_dim // num_heads). - dropout: - dropout probability on attn_output_weights. (default=0.0). - bias: - if specified, adds bias to input / output projection layers (default=True). - add_bias_kv: - if specified, adds bias to the key and value sequences at dim=0 (default=False). - add_zero_attn: - if specified, adds a new batch of zeros to the key and value sequences - at dim=1 (default=False). - batch_first: - if True, then the input and output tensors are provided as - (batch, seq, feature), otherwise (seq, batch, feature) (default=False). - - Examples:: - >>> multihead_attn = MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, - add_zero_attn: bool = False, - batch_first: bool = False, - device: Union[torch.device, str, None] = None, - dtype: Union[torch.dtype, str, None] = None, - ) -> None: - - super().__init__() - - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.batch_first = batch_first - - if embed_dim % num_heads != 0: - raise ValueError( - f"embed_dim must be divisible by num_heads. " - "Got embedding dim vs number 0f heads: " - f"{embed_dim} vs {num_heads}" - ) - - self.head_dim = embed_dim // num_heads - - self.in_proj = ScaledLinear( - embed_dim, - 3 * embed_dim, - bias=bias, - device=device, - dtype=dtype, - ) - self.out_proj = ScaledLinear( - embed_dim, - embed_dim, - bias=bias, - initial_scale=0.25, - device=device, - dtype=dtype, - ) - - if add_bias_kv: - self.bias_k = torch.nn.Parameter( - torch.empty((1, 1, embed_dim), device=device, dtype=dtype) - ) - self.bias_v = torch.nn.Parameter( - torch.empty((1, 1, embed_dim), device=device, dtype=dtype) - ) - else: - self.register_parameter("bias_k", None) - self.register_parameter("bias_v", None) - - self.add_zero_attn = add_zero_attn - - self._reset_parameters() - - def _reset_parameters(self) -> None: - if self.bias_k is not None: - torch.nn.init.xavier_normal_(self.bias_k) - if self.bias_v is not None: - torch.nn.init.xavier_normal_(self.bias_v) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - query: - Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q) - when batch_first=True, where L is the target sequence length, N is the batch size, - and E_q is the query embedding dimension embed_dim. Queries are compared against - key-value pairs to produce the output. See "Attention Is All You Need" for more details. - key: - Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when - batch_first=True, where S is the source sequence length, N is the batch size, and - E_k is the key embedding dimension kdim. 
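For reference, the single `in_proj` above packs the query, key and value projections into one weight of shape (3 * embed_dim, embed_dim); `torch.nn.functional.multi_head_attention_forward` later splits it apart. A small sketch of that packing with arbitrary example sizes (none of these names come from this file):

```
import torch

embed_dim, seq_len, batch = 8, 5, 3
x = torch.randn(seq_len, batch, embed_dim)

# Packed projection, as held by in_proj: weight (3*embed_dim, embed_dim), bias (3*embed_dim,).
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# Self-attention case: one matmul, then chunk into q, k, v.
q, k, v = torch.nn.functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

# Equivalent to three separate (embed_dim, embed_dim) projections.
w_q, w_k, w_v = in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = in_proj_bias.chunk(3, dim=0)
assert torch.allclose(q, torch.nn.functional.linear(x, w_q, b_q), atol=1e-6)
assert torch.allclose(k, torch.nn.functional.linear(x, w_k, b_k), atol=1e-6)
assert torch.allclose(v, torch.nn.functional.linear(x, w_v, b_v), atol=1e-6)
```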
See "Attention Is All You Need" for more details. - value: - Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when - batch_first=True, where S is the source sequence length, N is the batch size, and - E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details. - key_padding_mask: - If specified, a mask of shape (N, S) indicating which elements within key - to ignore for the purpose of attention (i.e. treat as "padding"). - Binary and byte masks are supported. For a binary mask, a True value indicates - that the corresponding key value will be ignored for the purpose of attention. - For a byte mask, a non-zero value indicates that the corresponding key value will be ignored. - need_weights: - If specifid, returns attn_output_weights in addition to attn_outputs (default=True). - attn_mask: - If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length, - and S is the source sequence length. A 2D mask will be broadcasted across the batch while - a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a True value indicates - that the corresponding position is not allowed to attend. For a byte mask, a non-zero - value indicates that the corresponding position is not allowed to attend. For a float mask, - the mask values will be added to the attention weight. - - Returns: - attn_output: - Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, - where L is the target sequence length, N is the batch size, and E is the embedding dimension - embed_dim. - attn_output_weights: - Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence - length, and S is the source sequence length. Only returned when need_weights=True. - """ - if self.batch_first: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - - ( - attn_output, - attn_output_weights, - ) = torch.nn.functional.multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - in_proj_weight=self.in_proj.get_weight(), - in_proj_bias=self.in_proj.get_bias(), - bias_k=self.bias_k, - bias_v=self.bias_v, - add_zero_attn=self.add_zero_attn, - dropout_p=self.dropout, - out_proj_weight=self.out_proj.get_weight(), - out_proj_bias=self.out_proj.get_bias(), - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - if self.batch_first: - return attn_output.transpose(1, 0), attn_output_weights - return attn_output, attn_output_weights diff --git a/egs/tedlium3/ASR/conformer_ctc2/combiner.py b/egs/tedlium3/ASR/conformer_ctc2/combiner.py deleted file mode 100644 index ff526029d..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/combiner.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2022 Behavox LLC. (author: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import torch - - -class RandomCombine(torch.nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0, - ) -> None: - """ - Args: - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: - The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: - The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: - A standard deviation that we add to log-probs for computing - randomized weights. - The method of choosing which layers, or combinations of layers, to use, - is conceptually as follows:: - With probability `pure_prob`:: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else:: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super().__init__() - assert 0 <= pure_prob <= 1, pure_prob - assert 0 < final_weight < 1, final_weight - assert num_inputs >= 1, num_inputs - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev = stddev - - self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) - .log() - .item() - ) - - def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. 
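A short usage sketch of the behaviour this docstring describes: in training the output is a random per-frame mixture of the collected layer outputs, while in eval it is exactly the last one. It assumes combiner.py (and hence RandomCombine) is importable; the shapes are arbitrary.

```
import torch

from combiner import RandomCombine  # assumes this file is on PYTHONPATH

combiner = RandomCombine(num_inputs=4, final_weight=0.5, pure_prob=0.333, stddev=2.0)

# Pretend these are outputs of four encoder layers, each of shape (T, N, C).
layer_outputs = [torch.randn(20, 2, 256) for _ in range(4)]

combiner.train()
y_train = combiner(layer_outputs)  # random mixture, same shape as each input
combiner.eval()
y_eval = combiner(layer_outputs)   # just the final input
assert y_train.shape == y_eval.shape == layer_outputs[-1].shape
assert torch.equal(y_eval, layer_outputs[-1])
```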
- """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}" - if not self.training or torch.jit.is_scripting() or len(inputs) == 1: - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights( - inputs[0].dtype, inputs[0].device, num_frames - ) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - - ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) - - return ans - - def _get_random_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired - Returns: - A tensor of shape (num_frames, self.num_inputs), such that - `ans.sum(dim=1)` is all ones. - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where( - torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m - ) - - def _get_random_pure_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A one-hot tensor of shape `(num_frames, self.num_inputs)`, with - exactly one weight equal to 1.0 on each frame. - """ - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where( - torch.rand(num_frames, device=device) < final_prob, final, nonfinal - ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) - return ans - - def _get_random_mixed_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A tensor of shape (num_frames, self.num_inputs), which elements - in [0..1] that sum to one over the second axis, i.e. - `ans.sum(dim=1)` is all ones. 
- """ - logprobs = ( - torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) - * self.stddev - ) - logprobs[:, -1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine( - final_weight: float, - pure_prob: float, - stddev: float, -) -> None: - print( - f"_test_random_combine: final_weight={final_weight}, " - f"pure_prob={pure_prob}, stddev={stddev}" - ) - num_inputs = 3 - num_channels = 50 - m = RandomCombine( - num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev, - ) - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -def _test_random_combine_main() -> None: - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - -if __name__ == "__main__": - _test_random_combine_main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/conformer.py b/egs/tedlium3/ASR/conformer_ctc2/conformer.py deleted file mode 100644 index fad2f371f..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/conformer.py +++ /dev/null @@ -1,1033 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# 2022 Xiaomi Corp. (author: Quandong Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import warnings -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -from combiner import RandomCombine -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledLinear, -) -from subsampling import Conv2dSubsampling -from transformer import Supervisions, Transformer, encoder_padding_mask - - -class Conformer(Transformer): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - aux_layer_period: int = 3, - ) -> None: - """ - Args: - num_features (int): - number of input features. - num_classes (int): - number of output classes. - subsampling_factor (int): - subsampling factor of encoder; - currently, subsampling_factor MUST be 4. - d_model (int): - attention dimension, also the output dimension. - nhead (int): - number of heads in multi-head attention; - must satisfy d_model // nhead == 0. - dim_feedforward (int): - feedforward dimention. - num_encoder_layers (int): - number of encoder layers. - num_decoder_layers (int): - number of decoder layers. - dropout (float): - dropout rate. - layer_dropout (float): - layer-dropout rate. - cnn_module_kernel (int): - kernel size of convolution module. 
- aux_layer_period (int): - determines the auxiliary encoder layers. - """ - - super().__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - layer_dropout=layer_dropout, - ) - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - layer_dropout=layer_dropout, - cnn_module_kernel=cnn_module_kernel, - ) - - # aux_layers from 1/3 - self.encoder = ConformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - aux_layers=list( - range( - num_encoder_layers // 3, - num_encoder_layers - 1, - aux_layer_period, - ) - ), - ) - - def run_encoder( - self, - x: torch.Tensor, - supervisions: Optional[Supervisions] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - the input tensor. Its shape is (batch_size, seq_len, feature_dim). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - warmup: - a floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - torch.Tensor: Predictor tensor of dimension (S, N, C). - torch.Tensor: Mask tensor of dimension (N, S) - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, S, C) -> (S, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - - x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup - ) # (S, N, C) - - return x, mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Examples: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - bypass_scale: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - ) -> None: - """ - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). 
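As a concrete example of the `aux_layers` selection shown a few lines up: with the defaults `num_encoder_layers=12` and `aux_layer_period=3`, the auxiliary outputs come from layers 4, 7 and 10, and `ConformerEncoder` later appends the final layer 11:

```
num_encoder_layers = 12
aux_layer_period = 3

aux_layers = list(range(num_encoder_layers // 3, num_encoder_layers - 1, aux_layer_period))
print(aux_layers)                             # [4, 7, 10]
print(aux_layers + [num_encoder_layers - 1])  # [4, 7, 10, 11], as used by the combiner
```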
- dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - bypass_scale: - a scale on the layer's output, used in bypass (resnet-type) skip-connection; - when the layer is bypassed the final output will be a - weighted sum of the layer's input and layer's output with weights - (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). - layer_dropout: - the probability to bypass the layer (default=0.075). - cnn_module_kernel (int): - kernel size of convolution module (default=31). - """ - super().__init__() - - if bypass_scale < 0.0 or bypass_scale > 1.0: - raise ValueError("bypass_scale should be between 0.0 and 1.0") - - if layer_dropout < 0.0 or layer_dropout > 1.0: - raise ValueError("layer_dropout should be between 0.0 and 1.0") - - self.bypass_scale = bypass_scale - self.layer_dropout = layer_dropout - - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - pos_emb: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: - the sequence to the encoder layer of shape (S, N, C) (required). - pos_emb: - positional embedding tensor of shape (N, 2*S-1, C) (required). - src_mask: - the mask for the src sequence of shape (S, S) (optional). - src_key_padding_mask: - the mask for the src keys per batch of shape (N, S) (optional). - warmup: - controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - - Returns: - Output tensor of the shape (S, N, C), where - S is the source sequence length, - N is the batch size, - C is the feature number - """ - src_orig = src - - warmup_scale = min(self.bypass_scale + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. 
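The warmup/bypass behaviour described in the comment above can be isolated into a small standalone sketch; `blend_with_input` is an illustrative name, not a function in this file:

```
import torch


def blend_with_input(
    src: torch.Tensor,
    layer_out: torch.Tensor,
    warmup: float,
    bypass_scale: float = 0.1,
    layer_dropout: float = 0.075,
    training: bool = True,
) -> torch.Tensor:
    """Return alpha * layer_out + (1 - alpha) * src, mimicking the
    warmup / layer-dropout bypass of the encoder layer."""
    warmup_scale = min(bypass_scale + warmup, 1.0)
    if training:
        # With probability layer_dropout, fall back to bypass_scale,
        # i.e. mostly bypass the layer for this call.
        bypassed = torch.rand(()).item() > (1.0 - layer_dropout)
        alpha = bypass_scale if bypassed else warmup_scale
    else:
        alpha = 1.0
    return alpha * layer_out + (1.0 - alpha) * src


src = torch.randn(10, 2, 256)        # (S, N, C)
layer_out = torch.randn(10, 2, 256)  # pretend this is the layer's output
print(blend_with_input(src, layer_out, warmup=0.2).shape)  # torch.Size([10, 2, 256])
```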
- if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else self.bypass_scale - ) - else: - alpha = 1.0 - - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - - # multi-headed self-attention module - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - - src = src + self.dropout(src_att) - - # convolution module - src = src + self.dropout(self.conv_module(src)) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src - - -class ConformerEncoder(nn.Module): - """ - ConformerEncoder is a stack of N encoder layers - - Examples: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: List[int], - ) -> None: - - """ - Args: - encoder_layer: - an instance of the ConformerEncoderLayer() class (required). - num_layers: - the number of sub-encoder-layers in the encoder (required). - aux_layers: - list of indexes of sub-encoder-layers outputs to be combined (required). - """ - - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert len(set(aux_layers)) == len(aux_layers) - - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - src: torch.Tensor, - pos_emb: torch.Tensor, - mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Pass the input through the encoder layers in turn. - - Args: - src: - the sequence to the encoder of shape (S, N, C) (required). - pos_emb: - positional embedding tensor of shape (N, 2*S-1, C) (required). - mask: - the mask for the src sequence of shape (S, S) (optional). - src_key_padding_mask: - the mask for the src keys per batch of shape (N, S) (optional). - warmup: - controls selective bypass of layer; if < 1.0, we will - bypass the layer more frequently (default=1.0). - - Returns: - Output tensor of the shape (S, N, C), where - S is the source sequence length, - N is the batch size, - C is the feature number. - - """ - output = src - - outputs = [] - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - - if i in self.aux_layers: - outputs.append(output) - - output = self.combiner(outputs) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. - - See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """ - Construct an PositionalEncoding object. 
- - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - super().__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: torch.Tensor) -> None: - """ - Reset the positional encodings. - - Args: - x: - input tensor (N, T, C), where - T is the source sequence length, - N is the batch size. - C is the feature number. - - """ - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means the position of the query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i<j). - pe_positive = torch.zeros(x.size(1), self.d_model) - pe_negative = torch.zeros(x.size(1), self.d_model) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe_positive[:, 0::2] = torch.sin(position * div_term) - pe_positive[:, 1::2] = torch.cos(position * div_term) - pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) - pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) - - # Reverse the order of positive indices and concat both positive and - # negative indices. This is used to support the shifting trick - # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) - pe_negative = pe_negative[1:].unsqueeze(0) - pe = torch.cat([pe_positive, pe_negative], dim=1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Add positional encoding. - - Args: - x: - input tensor (N, T, C). - - Returns: - torch.Tensor: Encoded tensor (N, T, C). - torch.Tensor: Encoded tensor (N, 2*T-1, C), where - T is the source sequence length, - N is the batch size. - C is the feature number. - - """ - self.extend_pe(x) - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - """ - Multi-Head Attention layer with relative position encoding - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context". - - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - """ - Args: - embed_dim: - total dimension of the model. - num_heads: - parallel attention heads. - dropout: - a Dropout layer on attn_output_weights. Default: 0.0. - """ - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear( - embed_dim, embed_dim, bias=True, initial_scale=0.25 - ) - - # linear transformation for positional encoding.
- self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() - - def _pos_bias_u(self): - return self.pos_bias_u * self.pos_bias_u_scale.exp() - - def _pos_bias_v(self): - return self.pos_bias_v * self.pos_bias_v_scale.exp() - - def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.01) - nn.init.normal_(self.pos_bias_v, std=0.01) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_emb: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask - and a value is True, the corresponding value on the attention - layer will be ignored. When given a byte mask and a value is - non-zero, the corresponding value on the attention layer will be ignored. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. - A 2D mask will be broadcasted for all the batches while a 3D - mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. 
- - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), - self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: torch.Tensor) -> torch.Tensor: - """ - Compute relative positional encoding. - - Args: - x: - input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - torch.Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_emb: torch.Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: torch.Tensor, - in_proj_bias: torch.Tensor, - dropout_p: float, - out_proj_weight: torch.Tensor, - out_proj_bias: torch.Tensor, - training: bool = True, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. - When the value is True, the corresponding value on the - attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. - A 2D mask will be broadcasted for all the batches while a 3D - mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. 
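The `rel_shift` defined above uses `as_strided` so that, for a score matrix of shape (batch, head, time1, 2*time1-1), output position (i, j) reads the input entry at last-dimension index (time1 - 1) + j - i, i.e. the score for relative offset j - i. A naive indexing version that makes this explicit (both functions below are standalone illustrations):

```
import torch


def rel_shift_strided(x: torch.Tensor) -> torch.Tensor:
    """Strided view, same construction as rel_shift above."""
    b, h, t1, n = x.shape
    assert n == 2 * t1 - 1
    return x.as_strided(
        (b, h, t1, t1),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (t1 - 1),
    )


def rel_shift_naive(x: torch.Tensor) -> torch.Tensor:
    """out[b, h, i, j] = x[b, h, i, (t1 - 1) + j - i]."""
    b, h, t1, n = x.shape
    assert n == 2 * t1 - 1
    out = torch.empty(b, h, t1, t1, dtype=x.dtype)
    for i in range(t1):
        for j in range(t1):
            out[:, :, i, j] = x[:, :, i, (t1 - 1) + j - i]
    return out


x = torch.randn(2, 4, 5, 9)  # (batch, head, time1, 2*time1-1)
assert torch.equal(rel_shift_strided(x), rel_shift_naive(x))
```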
If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." 
- ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported" - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self._pos_bias_u()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self._pos_bias_v()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - 
attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """ - ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - Construct a ConvolutionModule object. - - Args: - channels (int): - the number of channels of conv layers. - kernel_size (int): - kernerl size of conv layers. - bias (bool): - whether to use bias in conv layers (default=True). - """ - super().__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = ScaledConv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 - ) - - self.depthwise_conv = ScaledConv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channel_dim=1, min_positive=0.05, max_positive=1.0 - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.25, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Compute convolution module. - - Args: - x: - input tensor of shape (T, N, C). - - Returns: - torch.Tensor: Output tensor (T, N, C), where - T is the source sequence length, - N is the batch size, - C is the feature number. - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
- - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py deleted file mode 100755 index 28d39de70..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ /dev/null @@ -1,896 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Quandong Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from conformer import Conformer -from train import add_model_arguments - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - load_averaged_model, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (1) ctc-greedy-search. 
It only use CTC output and a sentence piece - model for decoding. It produces the same results with ctc-decoding. - - (2) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - - (6) attention-decoder. Extract n paths from the LM rescored - lattice, the path with the highest score is the decoding result. - - (7) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--lm-path", - type=str, - default="data/lm/G_4_gram.pt", - help="""The n-gram LM dir for rescoring. - It should contain either lm_fname.pt or lm_fname.fst.txt - """, - ) - - parser.add_argument( - "--result-dir", - type=str, - default="conformer_ctc2/exp/results", - help="Directory to store results.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. 
- """ - params = AttributeDict( - { - # parameters for conformer - "subsampling_factor": 4, - "feature_dim": 80, - # parameters for decoding - "search_beam": 15, - "output_beam": 8, - "min_active_states": 10, - "max_active_states": 7000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def ctc_greedy_search( - ctc_probs: torch.Tensor, - mask: torch.Tensor, -) -> List[List[int]]: - """Apply CTC greedy search - Args: - ctc_probs (torch.Tensor): (batch, max_len, num_bpe) - mask (torch.Tensor): (batch, max_len) - Returns: - best path result - """ - - _, max_index = ctc_probs.max(2) # (B, maxlen) - max_index = max_index.masked_fill_(mask, 0) # (B, maxlen) - - ret_hyps = [] - for hyp in max_index: - hyp = torch.unique_consecutive(hyp) - hyp = hyp[hyp > 0].tolist() - ret_hyps.append(hyp) - return ret_hyps - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. 
- """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - - nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - unk = bpe_model.decode(bpe_model.unk_id()).strip() - hyps = [[w for w in s.split() if w != unk] for s in hyps] - key = "ctc-decoding" - - return {key: hyps} - - if params.method == "ctc-greedy-search": - hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(hyps) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - unk = bpe_model.decode(bpe_model.unk_id()).strip() - hyps = [[w for w in s.split() if w != unk] for s in hyps] - key = "ctc-greedy-search" - - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. 
- best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [ - [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps - ] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method == "nbest": - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [ - [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps - ] - return {key: hyps} - - assert params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "1best": - best_path_dict = one_best_decoding( - lattice=lattice, - lm_scale_list=lm_scale_list, - ) - elif params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.method == "attention-decoder": - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [ - [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps - ] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. 
Each tuple contains three elements: - the cut ID, the reference transcript, and the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - if hyps_dict is not None: - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - else: - assert len(results) > 0, "It should not decode to empty in the first batch!" - this_batch = [] - for cut_id, ref_text in zip(cut_ids, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, [])) - - for lm_scale in results.keys(): - results[lm_scale].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -) -> None: - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs.
- errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main() -> None: - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_path = Path(args.lm_path) - args.result_dir = Path(args.result_dir) - - args.result_dir.mkdir(exist_ok=True) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - if params.method in ("ctc-decoding", "ctc-greedy-search"): - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ("nbest-rescoring", "whole-lattice-rescoring"): - assert params.lm_path.suffix in (".pt", ".txt") - - if params.lm_path.is_file() and params.lm_path.suffix == ".pt": - logging.info(f"Loading pre-compiled {params.lm_path.name}") - d = torch.load(params.lm_path, map_location=device) - G = k2.Fsa.from_dict(d) - elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt": - raise FileNotFoundError(f"No such language model file: '{params.lm_path}'") - else: - # here we pass only if LM filename ends with '.pt' and doesn't exist - # or if LM filename ends '.txt' and exists. - if ( - not params.lm_path.is_file() - and params.lm_path.suffix == ".pt" - and not ( - params.lm_path.parent / f"{params.lm_path.stem}.fst.txt" - ).is_file() - ): - raise FileNotFoundError( - f"No such language model file: '{params.lm_path}'\n" - "'.fst.txt' representation of the language model was " - "not found either." 
- ) - else: - # whatever params.lm_path.name we got lm_name.pt or lm_name.fst.txt - # we are going to load lm_name.fst.txt here - params.lm_path = params.lm_path.parent / params.lm_path.name.replace( - ".pt", ".fst.txt" - ) - logging.info(f"Loading {params.lm_path.name}") - logging.warning("It may take 8 minutes.") - with open(params.lm_path) as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save( - G.as_dict(), - params.lm_path.parent - / params.lm_path.name.replace(".fst.txt", ".pt"), - ) - - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = Conformer( - num_features=params.feature_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - d_model=params.dim_model, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - 
filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - - valid_cuts = tedlium.dev_cuts() - test_cuts = tedlium.test_cuts() - - valid_dl = tedlium.valid_dataloaders(valid_cuts) - test_dl = tedlium.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [valid_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -# Importing add_model_arguments from train.py also executes the -# module-level torch.set_num_interop_threads(1) in train.py, so -# without the check below we would set num_interop_threads twice -# (once in train.py and once in decode.py), which raises an error. -if torch.get_num_interop_threads() != 1: - torch.set_num_interop_threads(1) - -# The flag below controls whether to allow TF32 on matmul. This flag defaults to False -# in PyTorch 1.12 and later. -torch.backends.cuda.matmul.allow_tf32 = True - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py deleted file mode 100755 index b5bf911c2..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/export.py +++ /dev/null @@ -1,294 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Behavox LLC (Author: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging.
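For orientation, "model averaging" here simply means taking the element-wise mean of the parameters stored in several checkpoints. Below is a minimal sketch of the idea, assuming checkpoints that keep their weights under a "model" key (as this recipe's checkpoints do); the script itself relies on icefall's average_checkpoints / average_checkpoints_with_averaged_model helpers, and the function name below is only illustrative.

import torch


def average_state_dicts(filenames, device="cpu"):
    """Element-wise mean of the 'model' state dicts stored in the given checkpoints."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            # Only floating-point tensors are averaged; integer buffers are
            # kept from the first checkpoint.
            if avg[k].is_floating_point():
                avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
    return avg


# e.g. model.load_state_dict(average_state_dicts(filenames))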
-""" -Usage: -./conformer_ctc2/export.py \ - --exp-dir ./conformer_ctc2/exp \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `conformer_ctc2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/tedlium3/ASR - ./conformer_ctc2/decode.py \ - --exp-dir ./conformer_ctc2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from conformer import Conformer -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, num_tokens, str2bool - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. 
- """ - # parameters for conformer - params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80}) - return params - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info(params) - - logging.info("About to create model") - - model = Conformer( - num_features=params.feature_dim, - num_classes=params.vocab_size, - subsampling_factor=params.subsampling_factor, - d_model=params.dim_model, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - ) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - "Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to 
{filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/local b/egs/tedlium3/ASR/conformer_ctc2/local deleted file mode 120000 index c820590c5..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/local +++ /dev/null @@ -1 +0,0 @@ -../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/optim.py b/egs/tedlium3/ASR/conformer_ctc2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling.py b/egs/tedlium3/ASR/conformer_ctc2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py deleted file mode 120000 index 8c91f2336..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc2/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py deleted file mode 100755 index fc3e3b2d9..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ /dev/null @@ -1,1061 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Behavox LLC. (authors: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./conformer_ctc/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir conformer_ctc/exp \ - --max-duration 300 - -# For mix precision training: - -./conformer_ctc/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir conformer_ctc/exp \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -from asr_datamodule import TedLiumAsrDataModule -from conformer import Conformer -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - encode_supervisions, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--num-decoder-layers", - type=int, - default=6, - help="""Number of decoder layer of transformer decoder. - Setting this to 0 will not create the decoder at all (pure CTC model) - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.8, - help="""The attention rate. 
- The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward module dimension of the conformer model.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer multiheadattention modules.", - ) - - parser.add_argument( - "--dim-model", - type=int, - default=384, - help="Attention dimension in the conformer model.", - ) - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" and "bpe.model" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="Number of epochs that affects how rapidly the learning rate decreases.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. 
- It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for ctc loss - "beam_size": 10, - "reduction": "none", - "use_double_scores": True, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: torch.nn.Module, - model_avg: torch.nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. 
- model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that is used for training. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[torch.nn.Module, DDP], - model_avg: Optional[torch.nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used for training. - scheduler: - The learning rate scheduler used for training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[torch.nn.Module, DDP], - graph_compiler: BpeCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model( - feature, supervisions, warmup=warmup - ) - - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = convert_texts_into_ids(texts, graph_compiler.sp) - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate > 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - warmup=warmup, - ) - else: - att_loss = torch.tensor([0]) - - ctc_loss_is_finite = torch.isfinite(ctc_loss) - att_loss_is_finite = torch.isfinite(att_loss) - if torch.any(~ctc_loss_is_finite) or torch.any(~att_loss_is_finite): - logging.info( - "Not all losses are finite!\n" - f"ctc_loss: {ctc_loss}\n" - f"att_loss: {att_loss}" - ) - display_and_save_batch(batch, params=params, sp=graph_compiler.sp) - ctc_loss = ctc_loss[ctc_loss_is_finite] - att_loss = att_loss[att_loss_is_finite] - - # If the batch contains more than 10 utterances AND - # if either all ctc_loss or att_loss is inf or nan, - # we stop the training process by raising an exception - if torch.all(~ctc_loss_is_finite) or torch.all(~att_loss_is_finite): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - ctc_loss = ctc_loss.sum() - att_loss = att_loss.sum() - - if params.att_rate > 0.0: - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. 
- info["frames"] = ( - torch.div(feature_lens, params.subsampling_factor, rounding_mode="floor") - .sum() - .item() - ) - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate > 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[torch.nn.Module, DDP], - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch in valid_dl: - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[torch.nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[torch.nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=graph_compiler.sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - if "lang_bpe" not in str(params.lang_dir): - raise ValueError( - f"Unsupported type of lang dir (we expected it to have " - f"'lang_bpe' in its name): {params.lang_dir}" - ) - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - d_model=params.dim_model, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - num_decoder_layers=params.num_decoder_layers, - ) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[torch.nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = optim.Eve(model.parameters(), lr=params.initial_lr) - scheduler = optim.Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and checkpoints.get("optimizer") is not None: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if checkpoints and checkpoints.get("scheduler") is not None: - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - tedlium = TedLiumAsrDataModule(args) - - train_cuts = tedlium.train_cuts() - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = tedlium.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = tedlium.dev_cuts() - valid_dl = tedlium.valid_dataloaders(valid_cuts) - - if ( - params.start_epoch <= 1 - and params.start_batch <= 0 - and not params.print_diagnostics - ): - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in 
checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - train_dl.dataset.epoch = epoch - 1 - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[torch.nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=graph_compiler.sp) - raise - - -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# The flag below controls whether to allow TF32 on matmul. This flag defaults to False -# in PyTorch 1.12 and later. -torch.backends.cuda.matmul.allow_tf32 = True - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py deleted file mode 100644 index 9dbf32e48..000000000 --- a/egs/tedlium3/ASR/conformer_ctc2/transformer.py +++ /dev/null @@ -1,1093 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Copyright 2022 Xiaomi Corp. (author: Quandong Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from attention import MultiheadAttention -from combiner import RandomCombine -from label_smoothing import LabelSmoothingLoss -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledEmbedding, - ScaledLinear, -) -from subsampling import Conv2dSubsampling -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class Transformer(nn.Module): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - layer_dropout: float = 0.075, - aux_layer_period: int = 3, - ) -> None: - """ - Args: - num_features: - the input dimension of the model. - num_classes: - the output dimension of the model. - subsampling_factor: - number of output frames is num_in_frames // subsampling_factor; - currently, subsampling_factor MUST be 4. - d_model: - attention dimension. - nhead: - number of heads in multi-head attention; - must satisfy d_model % nhead == 0. - dim_feedforward: - the output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - number of encoder layers. - num_decoder_layers: - number of decoder layers. - dropout: - dropout in encoder/decoder. - layer_dropout: - layer-dropout rate. - aux_layer_period: - determines the auxiliary encoder layers. - """ - super().__init__() - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_classes) - # to the shape (N, T//subsampling_factor, d_model).
- # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - layer_dropout=layer_dropout, - ) - # aux_layers from 1/3 - self.encoder = TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - aux_layers=list( - range( - num_encoder_layers // 3, - num_encoder_layers - 1, - aux_layer_period, - ) - ), - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True) - ) - - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = ScaledEmbedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - ) - - self.decoder = TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - aux_layers=[], - ) - - self.decoder_output_layer = ScaledLinear( - d_model, self.decoder_num_class, bias=True - ) - - self.decoder_criterion = LabelSmoothingLoss(reduction="none") - else: - self.decoder_criterion = None - - def forward( - self, - x: torch.Tensor, - supervision: Optional[Supervisions] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, S, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - warmup: - a floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, S, C) - - Encoder output with shape (S, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, S). - It is None if `supervision` is None. - """ - - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision, warmup - ) - - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, - x: torch.Tensor, - supervisions: Optional[Supervisions] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, S, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. 
- warmup: - a floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - Return a tuple with two tensors: - - The encoder output, with shape (S, N, C) - - encoder padding mask, with shape (N, S). - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, S, C) -> (S, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (S, N, C) - - return x, mask - - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - the output tensor from the transformer encoder; - its shape is (S, N, C) - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is (N, S, C) - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (S, N, C) -> (N, S, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, S, C) - return x - - @torch.jit.export - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder of shape (S, N, C) - memory_key_padding_mask: - The padding mask from the encoder of shape (N, S). - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - warmup: - a floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. 
- tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder of shape (S, N, C). - memory_key_padding_mask: - The padding mask from the encoder of shape (N, S). - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - warmup: - a floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. 
- tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, С) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - warmup=warmup, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - - Example: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - bypass_scale: float = 0.1, - layer_dropout: float = 0.075, - ) -> None: - """ - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - bypass_scale: - a scale on the layer's output, used in bypass (resnet-type) skip-connection; - when the layer is bypassed the final output will be a - weighted sum of the layer's input and layer's output with weights - (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). - layer_dropout: - the probability to bypass the layer (default=0.075). - """ - - super().__init__() - - if bypass_scale < 0.0 or bypass_scale > 1.0: - raise ValueError("bypass_scale should be between 0.0 and 1.0") - - if layer_dropout < 0.0 or layer_dropout > 1.0: - raise ValueError("layer_dropout should be between 0.0 and 1.0") - - self.bypass_scale = bypass_scale - self.layer_dropout = layer_dropout - - self.self_attn = MultiheadAttention(d_model, nhead) - # Implementation of Feedforward model - - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: - the sequence to the encoder layer of shape (S, N, C) (required). - src_mask: - the mask for the src sequence of shape (S, S) (optional). - src_key_padding_mask: - the mask for the src keys per batch of shape (N, S) (optional) - warmup: - controls selective bypass of layers; if < 1.0, we will - bypass the layer more frequently (default=1.0). - - Returns: - Output tensor of the shape (S, N, C), where - S is the source sequence length, - N is the batch size, - C is the feature number. 
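The decoder targets built in decoder_forward/decoder_nll above follow a fixed recipe: ys_in prepends SOS and is padded with eos_id (which doubles as the padding symbol here), while ys_out appends EOS and is padded with -1 so padded positions are ignored by the loss. A small self-contained sketch of that construction and of the per-token negative log-likelihood, using a toy vocabulary and random logits rather than the real decoder:

import torch
from torch.nn.utils.rnn import pad_sequence

sos_id, eos_id, vocab_size = 1, 1, 10       # sos and eos share one id, as in this recipe
token_ids = [[4, 5, 6], [7, 8]]             # two toy utterances

ys_in = [torch.tensor([sos_id] + utt) for utt in token_ids]       # add_sos
ys_out = [torch.tensor(utt + [eos_id]) for utt in token_ids]      # add_eos
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

# Padded positions in ys_out_pad carry -1 and are skipped via ignore_index.
logits = torch.randn(ys_out_pad.shape[0], ys_out_pad.shape[1], vocab_size)
nll = torch.nn.functional.cross_entropy(
    logits.reshape(-1, vocab_size),
    ys_out_pad.reshape(-1),
    ignore_index=-1,
    reduction="none",
).view(ys_out_pad.shape[0], -1)             # (N, T): per-token negative log-likelihood
print(ys_in_pad, ys_out_pad, nll.shape)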
- - """ - src_orig = src - - warmup_scale = min(self.bypass_scale + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else self.bypass_scale - ) - else: - alpha = 1.0 - - src_att = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = src + self.dropout(src_att) - - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1.0 - alpha) * src_orig - - return src - - -class TransformerDecoderLayer(nn.Module): - """Modified from torch.nn.TransformerDecoderLayer. - - Example: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - bypass_scale: float = 0.1, - layer_dropout: float = 0.075, - ) -> None: - - """ - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - bypass_scale: - a scale on the layer's output, used in bypass (resnet-type) skip-connection; - when the layer is bypassed, the final output will be a - weighted sum of the layer's input and layer's output with weights - (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). - layer_dropout: - the probability to bypass the layer (default=0.075). - """ - - super().__init__() - - if bypass_scale < 0.0 or bypass_scale > 1.0: - raise ValueError("bypass_scale should be between 0.0 and 1.0") - - if layer_dropout < 0.0 or layer_dropout > 1.0: - raise ValueError("layer_dropout should be between 0.0 and 1.0") - - self.bypass_scale = bypass_scale - self.layer_dropout = layer_dropout - - self.self_attn = MultiheadAttention(d_model, nhead) - self.src_attn = MultiheadAttention(d_model, nhead) - - # Implementation of Feedforward model - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer of shape (T, N, C) (required). - memory: - the sequence from the last layer of the encoder of shape (S, N, C) (required). - tgt_mask: - the mask for the tgt sequence of shape (T, T) (optional). - memory_mask: - the mask for the memory sequence of shape (T, S) (optional). 
- tgt_key_padding_mask: - the mask for the tgt keys per batch of shape (N, T) (optional). - memory_key_padding_mask: - the mask for the memory keys per batch of shape (N, S) (optional). - warmup: controls selective bypass of layers; if < 1.0, we will - bypass the layer more frequently (default=1.0). - - Returns: - Output tensor of the shape (T, N, C), where - S is the source sequence length, - T is the target sequence length, - N is the batch size, - C is the feature number. - - """ - tgt_orig = tgt - - warmup_scale = min(self.bypass_scale + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else self.bypass_scale - ) - else: - alpha = 1.0 - - tgt_att = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - )[0] - tgt = tgt + self.dropout(tgt_att) - - src_att = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = tgt + self.dropout(src_att) - - tgt = tgt + self.dropout(self.feed_forward(tgt)) - - tgt = self.norm_final(self.balancer(tgt)) - - if alpha != 1.0: - tgt = alpha * tgt + (1.0 - alpha) * tgt_orig - - return tgt - - -class TransformerEncoder(nn.Module): - """TransformerEncoder is a stack of N encoder layers - - Examples: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: List[int], - ) -> None: - """ - Args: - encoder_layer: - an instance of the TransformerEncoderLayer() class (required). - num_layers: - the number of sub-encoder-layers in the encoder (required). - aux_layers: - list of indexes of sub-encoder-layers outputs to be combined (required). - """ - - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert len(set(aux_layers)) == len(aux_layers) - - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - src: torch.Tensor, - mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """Pass the input through the encoder layers in turn. - - Args: - src: - the input to the encoder of shape (S, N, C) (required). - mask: - the mask for the src sequence of shape (S, S) (optional). - src_key_padding_mask: - the mask for the src keys per batch of shape (N, S) (optional). - warmup: - controls selective bypass of layer; if < 1.0, we will - bypass the layer more frequently (default=1.0). - - Returns: - Output tensor of the shape (S, N, C), where - S is the source sequence length, - N is the batch size, - C is the feature number. 
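Both the encoder and decoder layers above share the same "layer dropout with warmup" trick: during training the layer output is blended with its input using a weight alpha that is normally min(bypass_scale + warmup, 1.0) but occasionally falls back to bypass_scale, so early in training the layer is frequently close to bypassed. A minimal sketch of just that blending rule, with a toy linear layer standing in for the real attention/feed-forward stack:

import torch

def bypass_forward(layer, x, warmup, bypass_scale=0.1, layer_dropout=0.075, training=True):
    """Blend layer(x) with x, mimicking the warmup/bypass logic of the layers above."""
    if training:
        warmup_scale = min(bypass_scale + warmup, 1.0)
        keep = torch.rand(()).item() <= (1.0 - layer_dropout)
        alpha = warmup_scale if keep else bypass_scale
    else:
        alpha = 1.0
    out = layer(x)
    return out if alpha == 1.0 else alpha * out + (1.0 - alpha) * x

layer = torch.nn.Linear(8, 8)                         # toy stand-in for a transformer layer
x = torch.randn(4, 8)
y_early = bypass_forward(layer, x, warmup=0.0)        # mostly the input early in training
y_late = bypass_forward(layer, x, warmup=1.0)         # full layer output once warmed up
print(y_early.shape, y_late.shape)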
- - """ - output = src - - outputs = [] - for i, mod in enumerate(self.layers): - output = mod( - output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - - if i in self.aux_layers: - outputs.append(output) - - output = self.combiner(outputs) - - return output - - -class TransformerDecoder(nn.Module): - """TransformerDecoder is a stack of N decoder layers - - Examples: - >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8) - >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = transformer_decoder(tgt, memory) - """ - - def __init__( - self, - decoder_layer: nn.Module, - num_layers: int, - aux_layers: List[int], - ) -> None: - """ - Args: - decoder_layer: - an instance of the TransformerDecoderLayer() class (required). - num_layers: - the number of decoder layers in the decoder (required). - aux_layers: - list of indexes of decoder layer outputs to be combined (required). - """ - - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(decoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert len(set(aux_layers)) == len(aux_layers) - - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - warmup: float = 1.0, - ) -> torch.Tensor: - """Pass the input (and mask) through the decoder layers in turn. - - Args: - tgt: - the sequence to the decoder of shape (T, N, C) (required). - memory: - the sequence from the last layer of the encoder of shape (S, N, C) (required). - tgt_mask: - the mask for the tgt sequence of shape (T, T) (optional). - memory_mask: - the mask for the memory sequence of shape (T, S) (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch of shape (N, T) (optional). - memory_key_padding_mask: - the mask for the memory keys per batch of shape (N, S) (optional). - warmup: - controls selective bypass of layer; if < 1.0, we will - bypass the layer more frequently (default=1.0). - - Returns: - Output tensor of the shape (T, N, C), where - S is the source sequence length, - T is the target sequence length, - N is the batch size, - C is the feature number. 
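The aux_layers bookkeeping above is easy to miss: with num_encoder_layers=12 and aux_layer_period=3 the constructor picks range(12 // 3, 12 - 1, 3) = [4, 7, 10], appends the final layer index 11, and RandomCombine mixes the outputs captured at those depths. A small sketch of the collection pattern, with a plain average standing in for RandomCombine (which lives in combiner.py and is not reproduced here):

import copy
import torch

num_layers, aux_layer_period = 12, 3
aux_layers = list(range(num_layers // 3, num_layers - 1, aux_layer_period)) + [num_layers - 1]
print(aux_layers)                         # [4, 7, 10, 11]

layer = torch.nn.Linear(16, 16)           # toy stand-in for TransformerEncoderLayer
layers = torch.nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

x = torch.randn(10, 16)
outputs = []
for i, mod in enumerate(layers):
    x = mod(x)
    if i in aux_layers:
        outputs.append(x)

# RandomCombine draws random (or one-hot) weights over these captured outputs;
# a uniform average is used here purely as a placeholder.
combined = torch.stack(outputs).mean(dim=0)
print(len(outputs), combined.shape)       # 4 captured outputs, torch.Size([10, 16])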
- - """ - output = tgt - - outputs = [] - for i, mod in enumerate(self.layers): - output = mod( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - warmup=warmup, - ) - - if i in self.aux_layers: - outputs.append(output) - - output = self.combiner(outputs) - - return output - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: Embedding dimension. - dropout: Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - T is the target sequence length, - N is the batch size, - C is the feature number. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: Input of shape is (N, T, C) - - Returns: - A tensor of the same shape (N, T, C), - T is the target sequence length, - N is the batch size, - C is the feature number. - - """ - self.extend_pe(x) - x = x + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Mask tensor of dimension (batch_size, input_length), - True denotes the masked indices. 
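For reference, the sinusoidal table that extend_pe builds above can be reproduced in a few lines. This toy version only demonstrates the sin/cos interleaving and the resulting shape; it is not a drop-in replacement for the module, which also caches the table and applies dropout:

import math
import torch

d_model, T = 8, 5
position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)       # (T, 1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)                                                                      # (d_model / 2,)
pe = torch.zeros(T, d_model)
pe[:, 0::2] = torch.sin(position * div_term)   # PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
pe[:, 1::2] = torch.cos(position * div_term)   # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
pe = pe.unsqueeze(0)                           # (1, T, d_model), added to x[:, :T, :]
print(pe.shape)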
- """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - A bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask tensor of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-lists of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-lists, where each sublist ends - with EOS ID. 
- """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/tedlium3/ASR/local/__init__.py b/egs/tedlium3/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tedlium3/ASR/local/compile_hlg.py b/egs/tedlium3/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/tedlium3/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/local/compute_fbank_musan.py b/egs/tedlium3/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/tedlium3/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py deleted file mode 100755 index 733ebf235..000000000 --- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file computes fbank features of the TedLium3 dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_tedlium(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "train", - "dev", - "test", - ) - - prefix = "tedlium" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. 
- for partition, m in manifests.items(): - if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cur_num_jobs = num_jobs if ex is None else 80 - cur_num_jobs = min(cur_num_jobs, len(cut_set)) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=cur_num_jobs, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - # Split long cuts into many short and un-overlapping cuts - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_tedlium() diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py deleted file mode 100644 index 19ba8d24b..000000000 --- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -""" -Convert a transcript based on words to a list of BPE ids. - -For example, if we use 2 as the encoding id of -Note: it, inserts a space token before each - -texts = ['this is a day'] -spm_ids = [[38, 33, 6, 15, 2, 316]] - -texts = [' this is a sunny day'] -spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]] - -texts = [''] -spm_ids = [[15, 2]] - -""" - -import argparse -import logging -from typing import List - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--texts", type=List[str], help="The input transcripts list.") - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - return parser.parse_args() - - -def convert_texts_into_ids( - texts: List[str], - sp: spm.SentencePieceProcessor, -) -> List[List[int]]: - """ - Args: - texts: - A string list of transcripts, such as ['Today is Monday', 'It's sunny']. - sp: - A sentencepiece BPE model. - Returns: - Return an integer list of bpe ids. 
- """ - y = [] - for text in texts: - if "" in text: - id_segments = sp.encode(text.split(""), out_type=int) - - y_ids = [] - for i in range(len(id_segments)): - y_ids += id_segments[i] - if i < len(id_segments) - 1: - y_ids += [sp.piece_to_id("▁"), sp.unk_id()] - else: - y_ids = sp.encode(text, out_type=int) - y.append(y_ids) - - return y - - -def main(): - args = get_args() - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - y = convert_texts_into_ids(texts=args.texts, sp=sp) - - logging.info(f"The input texts: {args.texts}") - logging.info(f"The encoding ids: {y}") - - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/local/display_manifest_statistics.py b/egs/tedlium3/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 52e152389..000000000 --- a/egs/tedlium3/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. -""" - - -from lhotse import load_manifest_lazy - - -def main(): - path = "./data/fbank/tedlium_cuts_train.jsonl.gz" - path = "./data/fbank/tedlium_cuts_dev.jsonl.gz" - path = "./data/fbank/tedlium_cuts_test.jsonl.gz" - - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -## train -Cuts count: 804789 -Total duration (hours): 1370.6 -Speech duration (hours): 1370.6 (100.0%) -*** -Duration statistics (seconds): -mean 6.1 -std 3.1 -min 0.5 -25% 3.7 -50% 6.0 -75% 8.3 -99.5% 14.9 -99.9% 16.6 -max 33.3 - -## dev -Cuts count: 507 -Total duration (hours): 1.6 -Speech duration (hours): 1.6 (100.0%) -*** -Duration statistics (seconds): -mean 11.3 -std 5.7 -min 0.5 -25% 7.5 -50% 10.6 -75% 14.4 -99.5% 29.8 -99.9% 37.7 -max 39.9 - -## test -Cuts count: 1155 -Total duration (hours): 2.6 -Speech duration (hours): 2.6 (100.0%) -*** -Duration statistics (seconds): -mean 8.2 -std 4.3 -min 0.3 -25% 4.6 -50% 8.2 -75% 10.9 -99.5% 22.1 -99.9% 26.7 -max 32.5 -""" diff --git a/egs/tedlium3/ASR/local/prepare_lang_bpe.py b/egs/tedlium3/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/tedlium3/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py deleted file mode 100755 index d4ccdd1e3..000000000 --- a/egs/tedlium3/ASR/local/prepare_transcripts.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(author: Mingshuang Luo) -# Copyright 2022 Behavox LLC. (author: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes input text file and removes all words -that iclude any character out of English alphabet. - -""" -import argparse -import logging -import re -from pathlib import Path - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--input-text-path", - type=str, - help="Input text file path.", - ) - parser.add_argument( - "--output-text-path", - type=str, - help="Output text file path.", - ) - - return parser.parse_args() - - -def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None: - """ - Args: - input_text_path: - The input data text file path, e.g., data/lang/train_orig.txt. - output_text_path: - The output data text file path, e.g., data/lang/train.txt. - - Return: - Saved text file in output_text_path. - """ - - foreign_chr_check = re.compile(r"[^a-z']") - - logging.info(f"Loading {input_text_path.name}") - with open(input_text_path, "r", encoding="utf8") as f: - texts = {t.rstrip("\n") for t in f} - - texts = { - " ".join([w for w in t.split() if foreign_chr_check.search(w) is None]) - for t in texts - } - - with open(output_text_path, "w+", encoding="utf8") as f: - for t in texts: - f.write(f"{t}\n") - - -def main() -> None: - args = get_args() - input_text_path = Path(args.input_text_path) - output_text_path = Path(args.output_text_path) - - logging.info(f"Generating {output_text_path.name}") - prepare_transcripts(input_text_path, output_text_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/tedlium3/ASR/local/prepare_words.py b/egs/tedlium3/ASR/local/prepare_words.py deleted file mode 100755 index a37d0f08f..000000000 --- a/egs/tedlium3/ASR/local/prepare_words.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Behavox LLC. (authors: Daniil Kulko) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input supervisions json dir "data/manifests" -consisting of tedlium_supervisions_train.json and does the following: - -1. Generate words.txt. 
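The filtering rule in prepare_transcripts above is simply "drop any word containing a character outside a-z and the apostrophe"; the same regex reappears in prepare_words below. A self-contained sketch of that filter on an in-memory list (no files involved, lowercase input assumed as in the recipe):

import re

foreign_chr_check = re.compile(r"[^a-z']")

texts = [
    "that's right",
    "the café was open",     # "café" contains a non a-z character and is dropped
    "hello world",
]
cleaned = {
    " ".join(w for w in t.split() if foreign_chr_check.search(w) is None)
    for t in texts
}
print(sorted(cleaned))       # ["hello world", "that's right", "the was open"]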
- -""" -import argparse -import logging -import re -from pathlib import Path - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="Output directory.", - ) - - return parser.parse_args() - - -def prepare_words(lang_dir: str) -> None: - """ - Args: - lang_dir: - The language directory, e.g., data/lang. - - Return: - The words.txt file. - """ - - words_orig_path = Path(lang_dir) / "words_orig.txt" - words_path = Path(lang_dir) / "words.txt" - - foreign_chr_check = re.compile(r"[^a-z']") - - logging.info(f"Loading {words_orig_path.name}") - with open(words_orig_path, "r", encoding="utf8") as f: - words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")} - words = {w for w in words if foreign_chr_check.search(w) is None and w != ""} - words.add("") - words = ["", "!SIL"] + sorted(words) + ["#0", "", ""] - - with open(words_path, "w+", encoding="utf8") as f: - for idx, word in enumerate(words): - f.write(f"{word} {idx}\n") - - -def main() -> None: - args = get_args() - lang_dir = Path(args.lang_dir) - - logging.info("Generating words.txt") - prepare_words(lang_dir) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/tedlium3/ASR/local/train_bpe_model.py b/egs/tedlium3/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/tedlium3/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh deleted file mode 100755 index 2f58ca0ee..000000000 --- a/egs/tedlium3/ASR/prepare.sh +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/tedlium3 -# You can find data, doc, legacy, LM, etc, inside it. -# You can download them from https://www.openslr.org/51 -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - 5000 - 2000 - 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/tedlium3, - # you can create a symlink - # - # ln -sfv /path/to/tedlium3 $dl_dir/tedlium3 - # - if [ ! -d $dl_dir/tedlium3 ]; then - lhotse download tedlium $dl_dir - mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3 - fi - - # Download big and small 4 gram lanuage models - if [ ! 
-d $dl_dir/lm ]; then - wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm - wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm - gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - #ln -sfv /path/to/musan $dl_dir/musan - - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare tedlium3 manifests" - if [ ! -f data/manifests/.tedlium3.done ]; then - # We assume that you have downloaded the tedlium3 corpus - # to $dl_dir/tedlium3 - mkdir -p data/manifests - lhotse prepare tedlium $dl_dir/tedlium3 data/manifests - touch data/manifests/.tedlium3.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifests" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -e data/manifests/.musan.done ]; then - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for tedlium3" - - if [ ! -e data/fbank/.tedlium3.done ]; then - mkdir -p data/fbank - - python3 ./local/compute_fbank_tedlium.py - - gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \ - gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz - mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \ - data/fbank/tedlium_cuts_train.jsonl.gz - - touch data/fbank/.tedlium3.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -e data/fbank/.musan.done ]; then - mkdir -p data/fbank - python3 ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare BPE train data and set of words" - lang_dir=data/lang - mkdir -p $lang_dir - - if [ ! -f $lang_dir/train.txt ]; then - gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt - - ./local/prepare_transcripts.py \ - --input-text-path $lang_dir/train_orig.txt \ - --output-text-path $lang_dir/train.txt - fi - - if [ ! -f $lang_dir/words.txt ]; then - - awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic | - sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt - - ./local/prepare_words.py --lang-dir $lang_dir - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang/words.txt $lang_dir - - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript data/lang/train.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "" - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! 
-f data/lm/G_4_gram_small.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - --max-arpa-warnings=-1 \ - $dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt - fi - - if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - --max-arpa-warnings=-1 \ - $dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compile HLG" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/HLG.pt ]; then - ./local/compile_hlg.py \ - --lang-dir $lang_dir \ - --lm G_4_gram_small - fi - done -fi diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/__init__.py b/egs/tedlium3/ASR/pruned_transducer_stateless/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py deleted file mode 120000 index 49b2ee483..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py deleted file mode 120000 index 7f9f6263f..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py b/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py deleted file mode 120000 index 8be0dc864..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py deleted file mode 100755 index abba9d403..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ /dev/null @@ -1,519 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless/decode.py \ - --epoch 29 \ - --avg 13 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search -./pruned_transducer_stateless/decode.py \ - --epoch 29 \ - --avg 13 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless/decode.py \ - --epoch 29 \ - --avg 13 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless/decode.py \ - --epoch 29 \ - --avg 13 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=13, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def 
decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - batch=batch, - decoding_graph=decoding_graph, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
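# The WER written to the summary file below is the standard word error rate:
# substitutions + deletions + insertions over the number of reference words,
# reported as a percentage.  Illustrative helper only -- the real numbers come
# from icefall.utils.write_error_stats, and this function name is made up here.
def word_error_rate(num_sub: int, num_del: int, num_ins: int, num_ref_words: int) -> float:
    return 100.0 * (num_sub + num_del + num_ins) / num_ref_words


# e.g. 3 substitutions, 1 deletion and 2 insertions over 100 reference words
assert abs(word_error_rate(3, 1, 2, 100) - 6.0) < 1e-9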
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - dev_cuts = tedlium.dev_cuts() - test_cuts = tedlium.test_cuts() - - dev_dl = tedlium.valid_dataloaders(dev_cuts) - test_dl = tedlium.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py deleted file mode 120000 index 206384eaa..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py b/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py deleted file mode 100644 index aa22f82ec..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
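# The "model averaging" mentioned above can be pictured as an element-wise mean
# over the saved state_dicts.  This is only a sketch of the idea: the recipe
# itself calls icefall.checkpoint.average_checkpoints, and the helper name plus
# the {"model": ...} checkpoint layout are assumptions taken from the code above.
import torch


def average_state_dicts(filenames):
    """Return the element-wise average of the "model" entries in checkpoints."""
    avg = None
    for name in filenames:
        state = torch.load(name, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}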
-""" -Usage: -./pruned_transducer_stateless/export.py \ - --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 29 \ - --avg 13 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/tedlium3/ASR - ./pruned_transducer_stateless/decode.py \ - --exp-dir ./pruned_transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 1 \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=13, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py b/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py deleted file mode 120000 index b3d677eb5..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/local b/egs/tedlium3/ASR/pruned_transducer_stateless/local deleted file mode 120000 index c820590c5..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/local +++ /dev/null @@ -1 +0,0 @@ -../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/model.py b/egs/tedlium3/ASR/pruned_transducer_stateless/model.py deleted file mode 120000 index 6b78aed54..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless/model.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py deleted file mode 100644 index 9e58fed00..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ /dev/null @@ -1,353 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
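# The two artifacts produced by export.py above are consumed differently:
# pretrained.pt is a plain state_dict checkpoint, while cpu_jit.pt is a
# TorchScript module that needs no Python model definition.  A small sketch,
# assuming the default exp-dir layout; the helper name is made up here.
import torch


def load_exported_model(exp_dir: str, use_jit: bool):
    if use_jit:
        # Scripted model saved by torch.jit.script(model).save(...)
        return torch.jit.load(f"{exp_dir}/cpu_jit.pt", map_location="cpu")
    # Plain checkpoint saved by torch.save({"model": model.state_dict()}, ...)
    checkpoint = torch.load(f"{exp_dir}/pretrained.pt", map_location="cpu")
    return checkpoint["model"]  # pass this to model.load_state_dict()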
-""" -Usage: - -(1) greedy search -./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by -./pruned_transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search and modified_beam_search ", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - hyps = [] - msg = f"Using {params.decoding_method}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py b/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py deleted file mode 120000 index fd7ca8b30..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py deleted file mode 100755 index b97bf6150..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
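# When greedy_search/beam_search fall back to the per-utterance loop above, the
# padded encoder output is sliced back into single utterances while keeping a
# batch dimension of one.  A self-contained illustration of that slicing; the
# shapes used here are invented for the example.
import torch

encoder_out = torch.randn(2, 10, 512)      # (N, T, C), padded to T = 10
encoder_out_lens = torch.tensor([10, 7])   # number of valid frames per utterance

per_utt = [
    encoder_out[i : i + 1, : encoder_out_lens[i]]
    for i in range(encoder_out.size(0))
]
assert per_utt[0].shape == (1, 10, 512)
assert per_utt[1].shape == (1, 7, 512)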
- -""" -To run this file, do: - - cd icefall/egs/tedlium3/ASR - python ./pruned_transducer_stateless/test_decoder.py -""" - -import torch -from decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - unk_id = 2 - embedding_dim = 128 - context_size = 4 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - unk_id=unk_id, - context_size=context_size, - ) - N = 100 - U = 20 - x = torch.randint(low=0, high=vocab_size, size=(N, U)) - y = decoder(x) - assert y.shape == (N, U, vocab_size) - - # for inference - x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x, need_pad=False) - assert y.shape == (N, 1, vocab_size) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py deleted file mode 100755 index 2455f3630..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ /dev/null @@ -1,767 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
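# The Decoder exercised by test_decoder.py above is "stateless": during search
# it is fed only the last `context_size` tokens of a hypothesis, padded with
# blanks at the start.  A tiny illustration of how that decoder input is built
# (the token ids here are arbitrary example values).
import torch

context_size = 2
blank_id = 0
hyp = [blank_id] * context_size + [17, 23, 5]

decoder_input = torch.tensor([hyp[-context_size:]])  # shape (1, context_size)
assert decoder_input.tolist() == [[23, 5]]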
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir pruned_transducer_stateless/exp \ - --max-duration 300 -""" - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids -from model import Transducer -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12350, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - attention_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, - # parameters for Noam - "warm_step": 80000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.vocab_size, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) 
- - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
- """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - unk_id = params.unk_id - y = convert_texts_into_ids(texts, sp=sp) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
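# Since the loss above is an unnormalized sum over utterances, the values
# reported in the logs are divided by the accumulated frame count, i.e.
# tot_loss["loss"] / tot_loss["frames"].  A toy example with plain dicts in
# place of MetricsTracker; the numbers are made up.
batch_stats = [
    {"loss": 420.0, "frames": 1200},
    {"loss": 350.0, "frames": 1000},
]
tot = {"loss": 0.0, "frames": 0}
for info in batch_stats:
    tot = {k: tot[k] + info[k] for k in tot}

per_frame_loss = tot["loss"] / tot["frames"]  # 770.0 / 2200 = 0.35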
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - tedlium = TedLiumAsrDataModule(args) - - train_cuts = tedlium.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 17 seconds - return 1.0 <= c.duration <= 17.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = tedlium.train_dataloaders(train_cuts) - valid_cuts = tedlium.dev_cuts() - valid_dl = tedlium.valid_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - 
sp=sp, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py b/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py deleted file mode 120000 index 214afed39..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/shared b/egs/tedlium3/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/tedlium3/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/README.md b/egs/tedlium3/ASR/transducer_stateless/README.md deleted file mode 100644 index 9b6ed62f1..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## Introduction - -The decoder, i.e., the prediction network, is from -https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 -(Rnn-Transducer with Stateless Prediction Network) - -You can use the following command to start the training: - -```bash -cd egs/tedlium3/ASR - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ - --max-duration 300 -``` diff --git a/egs/tedlium3/ASR/transducer_stateless/__init__.py b/egs/tedlium3/ASR/transducer_stateless/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py deleted file mode 100644 index a67cf8d04..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class TedLiumAsrDataModule: - """ - DataModule for k2 ASR experiments. 
- It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. TEDLium3 dev - and test). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. 
" - "A value less than 1 means to disable time warp.", - ) - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset.", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=10, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - max_frames_mask_fraction=0.15, - p=0.9, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to get Musan cuts") - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create train dataset") - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - else: - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - logging.info("About to create train dataloader") - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts_test: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - if self.args.on_the_fly_feats: - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - test = K2SpeechRecognitionDataset( - return_cuts=self.args.return_cuts, - ) - - test_sampler = DynamicBucketingSampler( - cuts_test, - max_duration=self.args.max_duration, - shuffle=False, - ) - - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tedlium_cuts_train.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz") diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py 
b/egs/tedlium3/ASR/transducer_stateless/beam_search.py deleted file mode 100644 index 1f99edaf3..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py +++ /dev/null @@ -1,539 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Dict, List, Optional - -import torch -from model import Transducer - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: - """ - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - encoder_out_len = torch.tensor([1]) - decoder_out_len = torch.tensor([1]) - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out, encoder_out_len, decoder_out_len - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y != blank_id and y != unk_id: - hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - return hyp - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. 
- """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def run_decoder( - ys: List[int], - model: Transducer, - decoder_cache: Dict[str, torch.Tensor], -) -> torch.Tensor: - """Run the neural decoder model for a given hypothesis. - - Args: - ys: - The current hypothesis. - model: - The transducer model. - decoder_cache: - Cache to save computations. - Returns: - Return a 1-D tensor of shape (decoder_out_dim,) containing - output of `model.decoder`. - """ - context_size = model.decoder.context_size - key = "_".join(map(str, ys[-context_size:])) - if key in decoder_cache: - return decoder_cache[key] - - device = model.device - - decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_cache[key] = decoder_out - - return decoder_out - - -def run_joiner( - key: str, - model: Transducer, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - encoder_out_len: torch.Tensor, - decoder_out_len: torch.Tensor, - joint_cache: Dict[str, torch.Tensor], -): - """Run the joint network given outputs from the encoder and decoder. 
- - Args: - key: - A key into the `joint_cache`. - model: - The transducer model. - encoder_out: - A tensor of shape (1, 1, encoder_out_dim). - decoder_out: - A tensor of shape (1, 1, decoder_out_dim). - encoder_out_len: - A tensor with value [1]. - decoder_out_len: - A tensor with value [1]. - joint_cache: - A dict to save computations. - Returns: - Return a tensor from the output of log-softmax. - Its shape is (vocab_size,). - """ - if key in joint_cache: - return joint_cache[key] - - logits = model.joiner( - encoder_out, - decoder_out, - encoder_out_len, - decoder_out_len, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - - joint_cache[key] = log_prob - - return log_prob - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - encoder_out_len = torch.tensor([1]) - decoder_out_len = torch.tensor([1]) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # current_encoder_out is of shape (1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - - current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) - - logits = model.joiner( - current_encoder_out, - decoder_out, - encoder_out_len.expand(decoder_out.size(0)), - decoder_out_len.expand(decoder_out.size(0)), - ) - # logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token != blank_id and new_token != unk_id: - new_ys.append(new_token) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) - - best_hyp = 
B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - return ys - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - encoder_out_len = torch.tensor([1]) - decoder_out_len = torch.tensor([1]) - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - decoder_out = run_decoder( - ys=y_star.ys, model=model, decoder_cache=decoder_cache - ) - - key = "_".join(map(str, y_star.ys[-context_size:])) - key += f"-t-{t}" - log_prob = run_joiner( - key=key, - model=model, - encoder_out=current_encoder_out, - decoder_out=decoder_out, - encoder_out_len=encoder_out_len, - decoder_out_len=decoder_out_len, - joint_cache=joint_cache, - ) - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for idx in range(values.size(0)): - i = indices[idx].item() - if i == blank_id or i == unk_id: - continue - - new_ys = y_star.ys + [i] - - new_log_prob = y_star.log_prob + values[idx] - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys diff --git a/egs/tedlium3/ASR/transducer_stateless/conformer.py b/egs/tedlium3/ASR/transducer_stateless/conformer.py deleted file mode 120000 index 8be0dc864..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py deleted file mode 100755 index fb0e3116b..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ 
/dev/null @@ -1,486 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./transducer_stateless/decode.py \ - --epoch 29 \ - --avg 11 \ - --exp-dir ./transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search -./transducer_stateless/decode.py \ - --epoch 29 \ - --avg 11 \ - --exp-dir ./transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./transducer_stateless/decode.py \ - --epoch 29 \ - --avg 11 \ - --exp-dir ./transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Tuple - -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=13, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""Used only when --decoding-method is - beam_search or modified_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=3, - help="""Maximum number of symbols per frame. 
- Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - else: - return {f"beam_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> and <unk> are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results.
- args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - dev_cuts = tedlium.dev_cuts() - test_cuts = tedlium.test_cuts() - - dev_dl = tedlium.valid_dataloaders(dev_cuts) - test_dl = tedlium.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py deleted file mode 100644 index f9a3814c6..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/decoder.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - unk_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - unk_id: - The ID of the unk symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - self.unk_id = unk_id - assert context_size >= 1, context_size - self.context_size = context_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, - kernel_size=context_size, - padding=0, - groups=embedding_dim, - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, embedding_dim). 
- """ - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - return embedding_out diff --git a/egs/tedlium3/ASR/transducer_stateless/encoder_interface.py b/egs/tedlium3/ASR/transducer_stateless/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py deleted file mode 100644 index 48dcdc736..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/export.py +++ /dev/null @@ -1,252 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./transducer_stateless/export.py \ - --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 29 \ - --avg 11 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `transducer_stateless/decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/tedlium3/ASR - ./transducer_stateless/decode.py \ - --exp-dir ./transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tedlium3/ASR/transducer_stateless/joiner.py b/egs/tedlium3/ASR/transducer_stateless/joiner.py deleted file mode 120000 index 1aec6bfaf..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/local b/egs/tedlium3/ASR/transducer_stateless/local deleted file mode 120000 index c820590c5..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/local +++ /dev/null @@ -1 +0,0 @@ -../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/model.py b/egs/tedlium3/ASR/transducer_stateless/model.py deleted file mode 120000 index 16ddd93f0..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/model.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py deleted file mode 100644 index 5300fe764..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ /dev/null @@ -1,338 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) greedy search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./transducer_stateless/exp/epoch-xx.pt`. 
- -Note: ./transducer_stateless/exp/pretrained.pt is generated by -./transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torch.nn as nn -import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer -from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search and modified_beam_search ", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=3, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. 
- """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/tedlium3/ASR/transducer_stateless/subsampling.py b/egs/tedlium3/ASR/transducer_stateless/subsampling.py deleted file mode 120000 index fd7ca8b30..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/test_decoder.py b/egs/tedlium3/ASR/transducer_stateless/test_decoder.py deleted file mode 100755 index cc5f64951..000000000 --- 
a/egs/tedlium3/ASR/transducer_stateless/test_decoder.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/tedlium3/ASR - python ./transducer_stateless/test_decoder.py -""" - -import torch -from decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - unk_id = 2 - embedding_dim = 128 - context_size = 4 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - unk_id=unk_id, - context_size=context_size, - ) - N = 100 - U = 20 - x = torch.randint(low=0, high=vocab_size, size=(N, U)) - y = decoder(x) - assert y.shape == (N, U, embedding_dim) - - # for inference - x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x, need_pad=False) - assert y.shape == (N, 1, embedding_dim) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py deleted file mode 100755 index c6fa34e70..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ /dev/null @@ -1,737 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ - --max-duration 300 -""" - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids -from model import Transducer -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lr-factor", - type=float, - default=5.0, - help="The lr_factor for Noam optimizer", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--modified-transducer-prob", - type=float, - default=0.25, - help="""The probability to use modified transducer loss. - In modified transduer, it limits the maximum number of symbols - per frame to 1. See also the option --max-sym-per-frame in - transducer_stateless/decode.py - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. 
- - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - attention_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for Noam - "warm_step": 80000, # For the 100h subset, use 8k - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. 
- - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - unk_id = params.unk_id - y = convert_texts_into_ids(texts, sp=sp) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - loss = model( - x=feature, - x_lens=feature_lens, - y=y, - modified_transducer_prob=params.modified_transducer_prob, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - tedlium = TedLiumAsrDataModule(args) - - train_cuts = tedlium.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 17 seconds - return 1.0 <= c.duration <= 17.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = tedlium.train_dataloaders(train_cuts) - valid_cuts = tedlium.dev_cuts() - valid_dl = tedlium.valid_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - 
sp=sp, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/transducer_stateless/transformer.py b/egs/tedlium3/ASR/transducer_stateless/transformer.py deleted file mode 120000 index 214afed39..000000000 --- a/egs/tedlium3/ASR/transducer_stateless/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/__init__.py b/egs/tedlium3/ASR/zipformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/tedlium3/ASR/zipformer/asr_datamodule.py b/egs/tedlium3/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index 49b2ee483..000000000 --- a/egs/tedlium3/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/beam_search.py b/egs/tedlium3/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/tedlium3/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py deleted file mode 100755 index 2c4123c20..000000000 --- a/egs/tedlium3/ASR/zipformer/decode.py +++ /dev/null @@ -1,833 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. 
The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - unk = sp.decode(sp.unk_id()).strip() - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - allow_partial=True, - ) - for hyp in sp.decode(hyp_tokens): - hyp = [w for w in hyp.split() if w != unk] - hyps.append(hyp) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - allow_partial=True, - ) - for hyp in hyp_tokens: - hyp = [word_table[i] for i in hyp if word_table[i] != unk] - hyps.append(hyp) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - allow_partial=True, - ) - for hyp in sp.decode(hyp_tokens): - hyp = [w for w in hyp.split() if w != unk] - hyps.append(hyp) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - 
beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - allow_partial=True, - ) - for hyp in sp.decode(hyp_tokens): - hyp = [w for w in hyp.split() if w != unk] - hyps.append(hyp) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyp = [w for w in hyp.split() if w != unk] - hyps.append(hyp) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyp = [w for w in hyp.split() if w != unk] - hyps.append(hyp) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp = [w for w in sp.decode(hyp).split() if w != unk] - hyps.append(hyp) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
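The result keys assembled above encode the active search hyperparameters so that hypotheses from different settings can live side by side in one dict. A self-contained sketch of that naming scheme, assuming the same params fields referenced in the deleted script:

from types import SimpleNamespace

# Hypothetical decoding configuration mirroring the fields used above.
params = SimpleNamespace(
    decoding_method="fast_beam_search_nbest_LG",
    beam=20.0,
    max_contexts=8,
    max_states=64,
    num_paths=200,
    nbest_scale=0.5,
    ngram_lm_scale=0.01,
    beam_size=4,
)

def result_key(params) -> str:
    """Reproduce the key-naming logic used when collecting hypotheses."""
    if params.decoding_method == "greedy_search":
        return "greedy_search"
    if "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
        if "LG" in params.decoding_method:
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
        return key
    return f"beam_size_{params.beam_size}"

print(result_key(params))
# beam_20.0_max_contexts_8_max_states_64_num_paths_200_nbest_scale_0.5_ngram_lm_scale_0.01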
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
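For reference, the WER summary written by save_results above is just a tab-separated table sorted so the best setting comes first. A tiny sketch with made-up WER values and an in-memory buffer in place of the real summary file:

import io

# Hypothetical WERs per decoding setting, as collected by save_results().
test_set_wers = {
    "greedy_search": 7.31,
    "beam_size_4": 7.02,
    "beam_20.0_max_contexts_8_max_states_64": 6.95,
}

# Sort by WER so the best (lowest) setting is listed first.
sorted_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

f = io.StringIO()
print("settings\tWER", file=f)
for key, val in sorted_wers:
    print(f"{key}\t{val}", file=f)
print(f.getvalue())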
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) 
to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - - dev_cuts = tedlium.dev_cuts() - test_cuts = tedlium.test_cuts() - - dev_dl = tedlium.test_dataloaders(dev_cuts) - test_dl = tedlium.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dls = [dev_dl, test_dl] - - for name, dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=name, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/zipformer/decoder.py b/egs/tedlium3/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/tedlium3/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/encoder_interface.py b/egs/tedlium3/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/tedlium3/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/export.py b/egs/tedlium3/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/tedlium3/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/joiner.py b/egs/tedlium3/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/tedlium3/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/local b/egs/tedlium3/ASR/zipformer/local deleted file mode 120000 index c820590c5..000000000 --- a/egs/tedlium3/ASR/zipformer/local +++ /dev/null @@ -1 +0,0 @@ -../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py deleted file mode 100644 index 65b052ab9..000000000 --- a/egs/tedlium3/ASR/zipformer/model.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos, make_pad_mask - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dim) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - rnnt_type: str = "regular", - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - rnnt_type: - The type of label topology to use for the transducer loss. One of "regular", - "modified", or "constrained". - Returns: - Return the transducer loss. 
- - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - x, x_lens = self.encoder_embed(x, x_lens) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - rnnt_type=rnnt_type, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). 
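One detail worth spelling out from the forward pass above is the boundary tensor handed to both k2 loss calls: one row per utterance, with only the last two columns filled in (the number of labels and the number of encoder frames), while the first two columns, the start positions, stay zero. A tiny sketch with made-up lengths:

import torch

# Hypothetical per-utterance lengths after the encoder: 3 utterances.
x_lens = torch.tensor([180, 160, 200])  # encoder output frames
y_lens = torch.tensor([23, 17, 31])     # number of BPE labels

# Same construction as in Transducer.forward above.
boundary = torch.zeros((x_lens.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

print(boundary)
# tensor([[  0,   0,  23, 180],
#         [  0,   0,  17, 160],
#         [  0,   0,  31, 200]])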
- logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - rnnt_type=rnnt_type, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/tedlium3/ASR/zipformer/my_profile.py b/egs/tedlium3/ASR/zipformer/my_profile.py deleted file mode 120000 index 3a90b2628..000000000 --- a/egs/tedlium3/ASR/zipformer/my_profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/optim.py b/egs/tedlium3/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/tedlium3/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/pretrained.py b/egs/tedlium3/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/tedlium3/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/scaling.py b/egs/tedlium3/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/tedlium3/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/scaling_converter.py b/egs/tedlium3/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/tedlium3/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/subsampling.py b/egs/tedlium3/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/tedlium3/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py deleted file mode 100755 index 14a44efb3..000000000 --- a/egs/tedlium3/ASR/zipformer/train.py +++ /dev/null @@ -1,1307 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
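The (simple_loss, pruned_loss) pair returned by Transducer.forward above is not used as-is: the deleted zipformer/train.py below blends the two with a warm-up schedule, leaning on the simple loss early in training and shifting weight to the pruned loss as batch_idx_train approaches warm_step. A self-contained sketch of that blending, using the script's defaults (warm_step=2000, --simple-loss-scale 0.5) and illustrative loss values:

def blend_losses(simple_loss, pruned_loss, batch_idx_train, warm_step=2000.0, s=0.5):
    """Warm-up blending of the simple and pruned RNN-T losses.

    `s` plays the role of --simple-loss-scale; before warm_step the simple
    loss dominates, afterwards the scales settle at (s, 1.0).
    """
    if batch_idx_train >= warm_step:
        simple_scale, pruned_scale = s, 1.0
    else:
        frac = batch_idx_train / warm_step
        simple_scale = 1.0 - frac * (1.0 - s)
        pruned_scale = 0.1 + 0.9 * frac
    return simple_scale * simple_loss + pruned_scale * pruned_loss

# Early on the pruned loss is heavily down-weighted ...
print(blend_losses(simple_loss=10.0, pruned_loss=8.0, batch_idx_train=0))     # 10.0 * 1.0 + 8.0 * 0.1 = 10.8
# ... and after warm-up both terms contribute with their final scales.
print(blend_losses(simple_loss=10.0, pruned_loss=8.0, batch_idx_train=5000))  # 10.0 * 0.5 + 8.0 * 1.0 = 13.0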
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import TedLiumAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids -from model import Transducer -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=50, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.04, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--rnnt-type", - type=str, - default="regular", - choices=["regular", "modified", "constrained"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=1, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNNT loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
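The docstring of `load_checkpoint_if_available` above explains which file is chosen when resuming. Here is a compact sketch of just that selection logic, with a made-up `exp_dir` for illustration:

```python
# Resume-file selection as described in load_checkpoint_if_available above:
# a batch-level checkpoint wins over an epoch-level one; a fresh run loads nothing.
from pathlib import Path
from typing import Optional


def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # nothing to load; training starts from scratch


print(resume_filename(Path("zipformer/exp"), start_batch=0, start_epoch=3))
# zipformer/exp/epoch-2.pt
```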
- warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = convert_texts_into_ids(texts, sp) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - rnnt_type=params.rnnt_type, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. 
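`compute_loss` above blends the simple and pruned RNN-T losses with weights that change over the first `warm_step` batches: the simple-loss weight decays from 1.0 toward `--simple-loss-scale`, while the pruned-loss weight ramps from 0.1 up to 1.0. The schedule in isolation, assuming a simple-loss scale of 0.5 purely for illustration:

```python
# The warm-up weighting used in compute_loss above, pulled out for inspection.
# warm_step=2000 matches get_params(); s=0.5 is an assumed simple-loss scale.
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    simple_scale = 1.0 - frac * (1.0 - s)  # decays 1.0 -> s
    pruned_scale = 0.1 + 0.9 * frac        # ramps 0.1 -> 1.0
    return simple_scale, pruned_scale


for step in (0, 500, 1000, 2000, 10000):
    print(step, loss_scales(step))
# 0 (1.0, 0.1)  500 (0.875, 0.325)  1000 (0.75, 0.55)  2000 (0.5, 1.0)  10000 (0.5, 1.0)
```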
- scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
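For fp16 runs, `train_one_epoch` above doubles a lingering small `GradScaler` scale every so often and treats a collapsing scale as a symptom of divergence. A distilled sketch with the same thresholds; the function name, the callback, and the plain `RuntimeError` are stand-ins for the original helpers (`save_bad_model`, `raise_grad_scale_is_too_small_error`):

```python
# Distilled version of the grad-scale watchdog in train_one_epoch above.
# In the original it runs every 100 batches and only when --use-fp16 is set.
import logging


def watch_grad_scale(scaler, batch_idx: int, on_bad_model=None) -> None:
    cur = scaler._scale.item()  # the original also reads this private field
    # Encourage a small scale to grow: below 8.0 it is doubled on every check,
    # between 8.0 and 32.0 only every 400 batches.
    if cur < 8.0 or (cur < 32.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)
    if cur < 0.01:
        if on_bad_model is not None:
            on_bad_model()  # e.g. dump a "bad model" checkpoint for inspection
        logging.warning(f"Grad scale is small: {cur}")
    if cur < 1.0e-05:
        raise RuntimeError(f"grad_scale {cur} is too small; training is likely diverging")
```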
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - tedlium = TedLiumAsrDataModule(args) - - train_cuts = tedlium.train_cuts() - train_cuts = train_cuts.filter(lambda c: 1.0 <= c.duration <= 20.0) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = tedlium.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = tedlium.dev_cuts() - valid_dl = tedlium.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - 
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - TedLiumAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/tedlium3/ASR/zipformer/zipformer.py b/egs/tedlium3/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/tedlium3/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md deleted file mode 100644 index f700fab9e..000000000 --- a/egs/timit/ASR/README.md +++ /dev/null @@ -1,3 +0,0 @@ - -Please refer to -for how to run models in this recipe. diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md deleted file mode 100644 index d8ceb82b6..000000000 --- a/egs/timit/ASR/RESULTS.md +++ /dev/null @@ -1,74 +0,0 @@ -## Results - -### TIMIT training results (Tdnn_LSTM_CTC) -#### 2021-11-16 -(Mingshuang Luo): Result of https://github.com/k2-fsa/icefall/pull/114 - -TensorBoard log is available at https://tensorboard.dev/experiment/qhA1o025Q322kO34SlhWzg/#scalars - -Pretrained model is available at https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_lstm_ctc - -The best decoding results (PER) are listed below, we got this results by averaging models from epoch 16 to 25, and using `whole-lattice-rescoring` with lm_scale equals to 0.08. - -||TEST| -|--|--| -|PER| 19.71% | - -You can use the following commands to reproduce our results: - -```bash -git clone https://github.com/k2-fsa/icefall -cd icefall - -cd egs/timit/ASR -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0" -python tdnn_lstm_ctc/train.py --bucketing-sampler True \ - --concatenate-cuts False \ - --max-duration 200 \ - --world-size 1 \ - --lang-dir data/lang_phone - -python tdnn_lstm_ctc/decode.py --epoch 25 \ - --avg 10 \ - --max-duration 20 \ - --lang-dir data/lang_phone -``` - -### TIMIT training results (Tdnn_LiGRU_CTC) -#### 2021-11-16 - -(Mingshuang Luo): Result of phone based Tdnn_LiGRU_CTC model, https://github.com/k2-fsa/icefall/pull/114 - -TensorBoard log is available at https://tensorboard.dev/experiment/IlQxeq5vQJ2SEVP94Y5fyg/#scalars - -Pretrained model is available at https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_ligru_ctc - -The best decoding results (PER) are listed below, we got this results by averaging models from epoch 9 to 25, and using `whole-lattice-rescoring` decoding method with lm_scale equals to 0.1. 
- -||TEST| -|--|--| -|PER| 17.66% | - -You can use the following commands to reproduce our results: - -```bash -git clone https://github.com/k2-fsa/icefall -cd icefall - -cd egs/timit/ASR -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0" -python tdnn_ligru_ctc/train.py --bucketing-sampler True \ - --concatenate-cuts False \ - --max-duration 200 \ - --world-size 1 \ - --lang-dir data/lang_phone - -python tdnn_ligru_ctc/decode.py --epoch 25 \ - --avg 17 \ - --max-duration 20 \ - --lang-dir data/lang_phone -``` diff --git a/egs/timit/ASR/local/__init__.py b/egs/timit/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py deleted file mode 100755 index c8562f4fb..000000000 --- a/egs/timit/ASR/local/compile_hlg.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_3_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path("data/lm/G.pt").is_file(): - logging.info("Loading pre-compiled G") - d = torch.load("data/lm/G.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/timit/ASR/local/compute_fbank_musan.py b/egs/timit/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/timit/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py deleted file mode 100755 index ecdf10ba9..000000000 --- a/egs/timit/ASR/local/compute_fbank_timit.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the TIMIT dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_timit(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "TRAIN", - "DEV", - "TEST", - ) - prefix = "timit" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" - if cuts_file.is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if partition == "TRAIN": - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(cuts_file) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_timit() diff --git a/egs/timit/ASR/local/prepare_lang.py b/egs/timit/ASR/local/prepare_lang.py deleted file mode 100755 index e9f283274..000000000 --- a/egs/timit/ASR/local/prepare_lang.py +++ /dev/null @@ -1,386 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - - sorted_ans = list(ans) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. 
- count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. 
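As a concrete illustration of what `add_disambig_symbols` above does, consider this tiny hypothetical lexicon (not TIMIT data) containing a homophone pair and a pronunciation that is a prefix of another; the expected output is shown in the comments:

```python
# Hypothetical input for add_disambig_symbols above. "hello"/"hallo" share a
# pronunciation, and the pronunciation of "ah" is a prefix of that of "ahead",
# so those three entries receive disambiguation symbols; "ahead" does not.
lexicon = [
    ("hello", ["hh", "ah", "l", "ow"]),
    ("hallo", ["hh", "ah", "l", "ow"]),
    ("ah", ["ah"]),
    ("ahead", ["ah", "hh", "eh", "d"]),
]

# add_disambig_symbols(lexicon) returns (new_lexicon, max_disambig), roughly:
#   ("hello", ["hh", "ah", "l", "ow", "#1"])
#   ("hallo", ["hh", "ah", "l", "ow", "#2"])
#   ("ah",    ["ah", "#1"])
#   ("ahead", ["ah", "hh", "eh", "d"])
# with max_disambig == 2, so "#1" and "#2" are later added to the token table.
```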
- Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - pronprob = 1.0 - score = -math.log(pronprob) - - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, score]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - tokens[i] = tokens[i] if i >= 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - - words = get_words(lexicon) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(lang_dir / "L.png", title="L") - L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py deleted file mode 100755 index 0cf0f0deb..000000000 --- a/egs/timit/ASR/local/prepare_lexicon.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input supervisions json dir "data/manifests" -consisting of supervisions_TRAIN.json and does the following: - -1. Generate lexicon.txt. - -""" -import argparse -import json -import logging -from pathlib import Path - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--manifests-dir", - type=str, - help="""Input directory. - """, - ) - parser.add_argument( - "--lang-dir", - type=str, - help="""Output directory. - """, - ) - - return parser.parse_args() - - -def prepare_lexicon(manifests_dir: str, lang_dir: str): - """ - Args: - manifests_dir: - The manifests directory, e.g., data/manifests. - lang_dir: - The language directory, e.g., data/lang_phone. - - Return: - The lexicon.txt file and the train.text in lang_dir. - """ - import gzip - - phones = set() - - supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" - lexicon = Path(lang_dir) / "lexicon.txt" - - logging.info(f"Loading {supervisions_train}!") - with gzip.open(supervisions_train, "r") as load_f: - for line in load_f.readlines(): - load_dict = json.loads(line) - text = load_dict["text"] - # list the phone units and filter the empty item - phones_list = list(filter(None, text.split())) - - for phone in phones_list: - if phone not in phones: - phones.add(phone) - - with open(lexicon, "w") as f: - for phone in sorted(phones): - f.write(phone + " " + phone) - f.write("\n") - f.write(" ") - f.write("\n") - - -def main(): - args = get_args() - manifests_dir = Path(args.manifests_dir) - lang_dir = Path(args.lang_dir) - - logging.info("Generating lexicon.txt") - prepare_lexicon(manifests_dir, lang_dir) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh deleted file mode 100644 index f25fe5add..000000000 --- a/egs/timit/ASR/prepare.sh +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -num_phones=39 -# Here we use num_phones=39 for modeling - -nj=15 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/timit -# You can find data, train_data.csv, test_data.csv, etc, inside it. -# You can download them from https://data.deepai.org/timit.zip -# -# - $dl_dir/lm -# This directory contains the language model(LM) downloaded from -# https://huggingface.co/luomingshuang/timit_lm, and the LM is based -# on 39 phones. About how to get these LM files, you can know it -# from https://github.com/luomingshuang/Train_LM_with_kaldilm. 
-# -# - lm_3_gram.arpa -# - lm_4_gram.arpa -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech -dl_dir=$PWD/download -splits_dir=$PWD/splits_dir - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download LM" - # We assume that you have installed the git-lfs, if not, you could install it - # using: `sudo apt-get install git-lfs && git-lfs install` - [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm - git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm - pushd $dl_dir/lm - git lfs pull - popd -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/timit, - # you can create a symlink - # - # ln -sfv /path/to/timit $dl_dir/timit - # - if [ ! -d $dl_dir/timit ]; then - lhotse download timit $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare timit manifest" - # We assume that you have downloaded the timit corpus - # to $dl_dir/timit - mkdir -p data/manifests - lhotse prepare timit -p $num_phones -j $nj $dl_dir/timit/data data/manifests -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for timit" - mkdir -p data/fbank - ./local/compute_fbank_timit.py -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - mkdir -p data/fbank - ./local/compute_fbank_musan.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" - lang_dir=data/lang_phone - mkdir -p $lang_dir - - ./local/prepare_lexicon.py \ - --manifests-dir data/manifests \ - --lang-dir $lang_dir - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/lm_3_gram.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! 
-f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/lm_4_gram.arpa > data/lm/G_4_gram.fst.txt - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone -fi diff --git a/egs/timit/ASR/shared b/egs/timit/ASR/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/timit/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/timit/ASR/tdnn_ligru_ctc/__init__.py b/egs/timit/ASR/tdnn_ligru_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/timit/ASR/tdnn_ligru_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_ligru_ctc/asr_datamodule.py deleted file mode 120000 index fa1b8cca3..000000000 --- a/egs/timit/ASR/tdnn_ligru_ctc/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py deleted file mode 100644 index 4beeed18c..000000000 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ /dev/null @@ -1,488 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import TimitAsrDataModule -from model import TdnnLiGRU - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--method", - type=str, - default="whole-lattice-rescoring", - help="""Decoding method. - Supported values are: - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. 
Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn_ligru_ctc/exp/"), - "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "subsampling_factor": 2, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - - nnet_output = model(feature) - # nnet_output is (N, T, C) - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] - - lm_scale_list = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09] - lm_scale_list += [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - else: - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - - ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - G=G, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out PERs, per-phone error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"per-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tPER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, PER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TimitAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. 
- G.labels[G.labels >= first_word_disambig_id] = 0 - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") - G = k2.Fsa.from_dict(d).to(device) - - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = TdnnLiGRU( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - timit = TimitAsrDataModule(args) - test_set = "TEST" - test_dl = timit.test_dataloaders() - - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - lexicon=lexicon, - G=G, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py deleted file mode 100644 index 9a594a969..000000000 --- a/egs/timit/ASR/tdnn_ligru_ctc/model.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional - -import torch -import torch.nn as nn -from torch import Tensor - - -class TdnnLiGRU(nn.Module): - def __init__( - self, num_features: int, num_classes: int, subsampling_factor: int = 3 - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - It reduces the number of output frames by this factor. 
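Editor's note: when --avg is larger than 1, the decode script above averages the last --avg epoch checkpoints (epochs epoch - avg + 1 through epoch) before decoding. The sketch below is a simplified stand-in for icefall.checkpoint.average_checkpoints(), assuming each epoch-*.pt stores its weights under the "model" key as the load_checkpoint() call implies; the filenames in the comment are placeholders.

import torch

def average_state_dicts(filenames):
    """Element-wise mean of the "model" state dicts in the given checkpoints."""
    avg = None
    for f in filenames:
        sd = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}

# e.g. for --epoch 25 --avg 5:
# model.load_state_dict(average_state_dicts([f"exp/epoch-{i}.pt" for i in range(21, 26)]))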
- """ - super().__init__() - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=self.subsampling_factor, # stride: subsampling_factor! - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - ) - self.ligrus = nn.ModuleList( - [ - LiGRU( - input_shape=[None, None, 512], - hidden_size=512, - num_layers=1, - bidirectional=True, - ) - for _ in range(4) - ] - ) - self.linears = nn.ModuleList( - [nn.Linear(in_features=1024, out_features=512) for _ in range(4)] - ) - self.bnorms = nn.ModuleList( - [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(4)] - ) - self.dropout = nn.Dropout(0.2) - self.linear = nn.Linear(in_features=512, out_features=self.num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Its shape is [N, C, T] - - Returns: - The output tensor has shape [N, T, C] - """ - x = self.tdnn(x) - x = x.permute(0, 2, 1) - for ligru, linear, bnorm in zip(self.ligrus, self.linears, self.bnorms): - x_new, _ = ligru(x) - x_new = linear(x_new) - x_new = bnorm(x_new.permute(0, 2, 1)).permute(0, 2, 1) - # (N, T, C) -> (N, C, T) -> (N, T, C) - x_new = self.dropout(x_new) - x = x_new + x # skip connections - - x = self.linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x - - -class LiGRU(torch.nn.Module): - """This function implements a Light GRU (liGRU). - This LiGRU model is from speechbrain, please see - https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/RNN.py - - LiGRU is single-gate GRU model based on batch-norm + relu - activations + recurrent dropout. For more info see: - - "M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio, - Light Gated Recurrent Units for Speech Recognition, - in IEEE Transactions on Emerging Topics in Computational Intelligence, - 2018" (https://arxiv.org/abs/1803.10225) - - This is a custm RNN and to speed it up it must be compiled with - the torch just-in-time compiler (jit) right before using it. - You can compile it with: - compiled_model = torch.jit.script(model) - - It accepts in input tensors formatted as (batch, time, fea). - In the case of 4d inputs like (batch, time, fea, channel) the tensor is - flattened as (batch, time, fea*channel). - - Arguments - --------- - hidden_size : int - Number of output neurons (i.e, the dimensionality of the output). - values (i.e, time and frequency kernel sizes respectively). - input_shape : tuple - The shape of an example input. - nonlinearity : str - Type of nonlinearity (tanh, relu). - normalization : str - Type of normalization for the ligru model (batchnorm, layernorm). - Every string different from batchnorm and layernorm will result - in no normalization. - num_layers : int - Number of layers to employ in the RNN architecture. - bias : bool - If True, the additive bias b is adopted. 
- dropout : float - It is the dropout factor (must be between 0 and 1). - bidirectional : bool - If True, a bidirectional model that scans the sequence both - right-to-left and left-to-right is used. - - Example - ------- - >>> inp_tensor = torch.rand([4, 10, 20]) - >>> net = LiGRU(input_shape=inp_tensor.shape, hidden_size=5) - >>> out_tensor, _ = net(inp_tensor) - >>> - torch.Size([4, 10, 5]) - """ - - def __init__( - self, - hidden_size, - input_shape, - nonlinearity="relu", - normalization="batchnorm", - num_layers=1, - bias=True, - dropout=0.0, - bidirectional=False, - ): - super().__init__() - self.hidden_size = hidden_size - self.nonlinearity = nonlinearity - self.num_layers = num_layers - self.normalization = normalization - self.bias = bias - self.dropout = dropout - self.bidirectional = bidirectional - self.reshape = False - - # Computing the feature dimensionality - if len(input_shape) > 3: - self.reshape = True - self.fea_dim = float(torch.prod(torch.tensor(input_shape[2:]))) - self.batch_size = input_shape[0] - self.rnn = self._init_layers() - - def _init_layers(self): - """Initializes the layers of the liGRU.""" - rnn = torch.nn.ModuleList([]) - current_dim = self.fea_dim - - for i in range(self.num_layers): - rnn_lay = LiGRU_Layer( - current_dim, - self.hidden_size, - self.num_layers, - self.batch_size, - dropout=self.dropout, - nonlinearity=self.nonlinearity, - normalization=self.normalization, - bidirectional=self.bidirectional, - ) - rnn.append(rnn_lay) - - if self.bidirectional: - current_dim = self.hidden_size * 2 - else: - current_dim = self.hidden_size - return rnn - - def forward(self, x, hx: Optional[Tensor] = None): - """Returns the output of the liGRU. - - Arguments - --------- - x : torch.Tensor - The input tensor. - hx : torch.Tensor - Starting hidden state. - """ - # Reshaping input tensors for 4d inputs - if self.reshape: - if x.ndim == 4: - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]) - - # run ligru - output, hh = self._forward_ligru(x, hx=hx) - - return output, hh - - def _forward_ligru(self, x, hx: Optional[Tensor]): - """Returns the output of the vanilla liGRU. - - Arguments - --------- - x : torch.Tensor - Input tensor. - hx : torch.Tensor - """ - h = [] - if hx is not None: - if self.bidirectional: - hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size) - # Processing the different layers - for i, ligru_lay in enumerate(self.rnn): - if hx is not None: - x = ligru_lay(x, hx=hx[i]) - else: - x = ligru_lay(x, hx=None) - h.append(x[:, -1, :]) - h = torch.stack(h, dim=1) - - if self.bidirectional: - h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size) - else: - h = h.transpose(0, 1) - - return x, h - - -class LiGRU_Layer(torch.nn.Module): - """This function implements Light-Gated Recurrent Units (ligru) layer. - - Arguments - --------- - input_size : int - Feature dimensionality of the input tensors. - batch_size : int - Batch size of the input tensors. - hidden_size : int - Number of output neurons. - num_layers : int - Number of layers to employ in the RNN architecture. - nonlinearity : str - Type of nonlinearity (tanh, relu). - normalization : str - Type of normalization (batchnorm, layernorm). - Every string different from batchnorm and layernorm will result - in no normalization. - dropout : float - It is the dropout factor (must be between 0 and 1). - bidirectional : bool - if True, a bidirectional model that scans the sequence both - right-to-left and left-to-right is used. 
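Editor's note, a minimal usage sketch for the LiGRU class adapted from its own docstring example, assuming this recipe's model.py is importable from the working directory; eval mode is used so the pre-sampled recurrent dropout masks are not needed.

import torch
from model import LiGRU  # the class defined in this file

inp = torch.rand([4, 10, 20])                      # (batch, time, features)
net = LiGRU(input_shape=inp.shape, hidden_size=5)
net.eval()                                         # recurrent dropout is a no-op in eval mode
with torch.no_grad():
    out, h = net(inp)
print(out.shape)                                   # torch.Size([4, 10, 5])
# The docstring above suggests jit-compiling this custom RNN for speed:
# compiled = torch.jit.script(net)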
- """ - - def __init__( - self, - input_size, - hidden_size, - num_layers, - batch_size, - dropout=0.0, - nonlinearity="relu", - normalization="batchnorm", - bidirectional=False, - ): - - super(LiGRU_Layer, self).__init__() - self.hidden_size = int(hidden_size) - self.input_size = int(input_size) - self.batch_size = batch_size - self.bidirectional = bidirectional - self.dropout = dropout - self.drop = torch.nn.Dropout(p=self.dropout, inplace=False) - self.N_drop_masks = 16000 - self.drop_mask_cnt = 0 - self.drop_mask_te = torch.tensor([1.0]).float() - self.w = nn.Linear(self.input_size, 2 * self.hidden_size, bias=False) - self.u = nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=False) - - # Initializing batch norm - self.normalize = False - - if normalization == "batchnorm": - self.norm = nn.BatchNorm1d(2 * self.hidden_size, momentum=0.05) - self.normalize = True - - elif normalization == "layernorm": - self.norm = torch.nn.LayerNorm(2 * self.hidden_size) - self.normalize = True - else: - # Normalization is disabled here. self.norm is only formally - # initialized to avoid jit issues. - self.norm = torch.nn.LayerNorm(2 * self.hidden_size) - self.normalize = True - - # Initial state - self.register_buffer("h_init", torch.zeros(1, self.hidden_size)) - - # Setting the activation function - if nonlinearity == "tanh": - self.act = torch.nn.Tanh() - elif nonlinearity == "sin": - self.act = torch.sin - elif nonlinearity == "leaky_relu": - self.act = torch.nn.LeakyReLU() - else: - self.act = torch.nn.ReLU() - - def forward(self, x, hx: Optional[Tensor] = None): - # type: (Tensor, Optional[Tensor]) -> Tensor # noqa F821 - """Returns the output of the liGRU layer. - - Arguments - --------- - x : torch.Tensor - Input tensor. - """ - if self.bidirectional: - x_flip = x.flip(1) - x = torch.cat([x, x_flip], dim=0) - - # Change batch size if needed - self._change_batch_size(x) - - # Feed-forward affine transformations (all steps in parallel) - w = self.w(x) - - # Apply batch normalization - if self.normalize: - w_bn = self.norm(w.reshape(w.shape[0] * w.shape[1], w.shape[2])) - w = w_bn.reshape(w.shape[0], w.shape[1], w.shape[2]) - - # Processing time steps - if hx is not None: - h = self._ligru_cell(w, hx) - else: - h = self._ligru_cell(w, self.h_init) - - if self.bidirectional: - h_f, h_b = h.chunk(2, dim=0) - h_b = h_b.flip(1) - h = torch.cat([h_f, h_b], dim=2) - - return h - - def _ligru_cell(self, w, ht): - """Returns the hidden states for each time step. - - Arguments - --------- - wx : torch.Tensor - Linearly transformed input. - """ - hiddens = [] - - # Sampling dropout mask - drop_mask = self._sample_drop_mask(w) - - # Loop over time axis - for k in range(w.shape[1]): - gates = w[:, k] + self.u(ht) - at, zt = gates.chunk(2, 1) - zt = torch.sigmoid(zt) - hcand = self.act(at) * drop_mask - ht = zt * ht + (1 - zt) * hcand - hiddens.append(ht) - - # Stacking hidden states - h = torch.stack(hiddens, dim=1) - return h - - def _init_drop(self, batch_size): - """Initializes the recurrent dropout operation. To speed it up, - the dropout masks are sampled in advance. 
- """ - self.N_drop_masks = 16000 - self.drop_mask_cnt = 0 - - self.register_buffer( - "drop_masks", - self.drop(torch.ones(self.N_drop_masks, self.hidden_size)).data, - ) - self.register_buffer("drop_mask_te", torch.tensor([1.0]).float()) - - def _sample_drop_mask(self, w): - """Selects one of the pre-defined dropout masks""" - if self.training: - - # Sample new masks when needed - if self.drop_mask_cnt + self.batch_size > self.N_drop_masks: - self.drop_mask_cnt = 0 - self.drop_masks = self.drop( - torch.ones(self.N_drop_masks, self.hidden_size, device=w.device) - ).data - - # Sampling the mask - left_boundary = self.drop_mask_cnt - right_boundary = self.drop_mask_cnt + self.batch_size - drop_mask = self.drop_masks[left_boundary:right_boundary] - self.drop_mask_cnt = self.drop_mask_cnt + self.batch_size - - else: - self.drop_mask_te = self.drop_mask_te.to(w.device) - drop_mask = self.drop_mask_te - - return drop_mask - - def _change_batch_size(self, x): - """This function changes the batch size when it is different from - the one detected in the initialization method. This might happen in - the case of multi-gpu or when we have different batch sizes in train - and test. We also update the h_int and drop masks. - """ - if self.batch_size != x.shape[0]: - self.batch_size = x.shape[0] - - if self.training: - self.drop_masks = self.drop( - torch.ones( - self.N_drop_masks, - self.hidden_size, - device=x.device, - ) - ).data diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py deleted file mode 100644 index 0d77bc512..000000000 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import TdnnLiGRU -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. 
- (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.1, - help=""" - Used only when method is whole-lattice-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 2, - "num_classes": 41, - "sample_rate": 16000, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = TdnnLiGRU( - num_features=params.feature_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method == "whole-lattice-rescoring": - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, 
expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - features = features.permute(0, 2, 1) # now features is (N, C, T) - - with torch.no_grad(): - nnet_output = model(features) - # nnet_output is (N, T, C) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py deleted file mode 100644 index 48b7feda0..000000000 --- a/egs/timit/ASR/tdnn_ligru_ctc/train.py +++ /dev/null @@ -1,601 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
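Editor's note: pretrained.py above pads the per-utterance fbank matrices to a common length before batching; the padding value log(1e-10) corresponds to near-zero energy in log-mel space, so padded frames look like silence to the network. A self-contained sketch of that step, with random features standing in for real fbank output:

import math
import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(130, 80), torch.randn(97, 80)]   # (T, C) per utterance
features = pad_sequence(feats, batch_first=True, padding_value=math.log(1e-10))
print(features.shape)                                  # torch.Size([2, 130, 80])
features = features.permute(0, 2, 1)                   # (N, C, T), as the model expects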
- - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import TimitAsrDataModule -from lhotse.utils import fix_random_seed -from model import TdnnLiGRU -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=25, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn_ligru_ctc/exp"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-3, - "feature_dim": 80, - "weight_decay": 5e-4, - "subsampling_factor": 2, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of TdnnLstm in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. 
- graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. 
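Editor's note: compute_loss() above evaluates CTC with k2 against an FSA decoding graph compiled from the transcript. As a rough, deliberately swapped-in illustration of the same quantity for a plain label sequence, PyTorch's built-in CTCLoss with "sum" reduction can be normalized by the total number of frames, which is how the per-frame loss logged through MetricsTracker is obtained. All shapes and labels below are toy values.

import torch

log_probs = torch.randn(150, 2, 41).log_softmax(-1)   # (T, N, num_classes), 0 = blank
targets = torch.randint(1, 41, (2, 30))
input_lengths = torch.tensor([150, 130])
target_lengths = torch.tensor([30, 25])

ctc = torch.nn.CTCLoss(blank=0, reduction="sum")
loss = ctc(log_probs, targets, input_lengths, target_lengths)
frames = input_lengths.sum().item()
print(loss.item() / frames)                            # per-frame loss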
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = TdnnLiGRU( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=2, gamma=0.8) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) - - timit = TimitAsrDataModule(args) - train_dl = timit.train_dataloaders() - valid_dl = timit.valid_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if epoch > params.start_epoch: - logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") - - if tb_writer is not None: - tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - scheduler.step() - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - TimitAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/__init__.py b/egs/timit/ASR/tdnn_lstm_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py deleted file mode 100644 index 8606a490b..000000000 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import List, Union - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class TimitAsrDataModule(DataModule): - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz") - - logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20))] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[ - "num_frame_masks" - ] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms = [ - SpecAugment( - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ] - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get dev cuts") - cuts_valid = self.valid_cuts() - - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = SimpleCutSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: - cuts = self.test_cuts() - is_list = isinstance(cuts, list) - test_loaders = [] - if not is_list: - cuts = [cuts] - - for cuts_test in cuts: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = SimpleCutSampler(cuts_test, max_duration=self.args.max_duration) - logging.debug("About to create test dataloader") - test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) - test_loaders.append(test_dl) - - if is_list: - return test_loaders - else: - return test_loaders[0] - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.feature_dir / "timit_cuts_TRAIN.jsonl.gz" - ) - - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.feature_dir / "timit_cuts_DEV.jsonl.gz" - ) - - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.debug("About to get test cuts") - cuts_test = load_manifest_lazy( - self.args.feature_dir / "timit_cuts_TEST.jsonl.gz" - ) - - return cuts_test diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py deleted file mode 100644 index 502a48def..000000000 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ /dev/null 
@@ -1,486 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import TimitAsrDataModule -from model import TdnnLstm - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=25, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--method", - type=str, - default="whole-lattice-rescoring", - help="""Decoding method. - Supported values are: - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. 
- """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "subsampling_factor": 3, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - - nnet_output = model(feature) - # nnet_output is (N, T, C) - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] - - lm_scale_list = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09] - lm_scale_list += [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - else: - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - - ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - G=G, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out PERs, per-phone error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"per-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tPER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, PER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - TimitAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. 
- G.labels[G.labels >= first_word_disambig_id] = 0 - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") - G = k2.Fsa.from_dict(d).to(device) - - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - timit = TimitAsrDataModule(args) - test_set = "TEST" - test_dl = timit.test_dataloaders() - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - lexicon=lexicon, - G=G, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py deleted file mode 100644 index e211ad80d..000000000 --- a/egs/timit/ASR/tdnn_lstm_ctc/model.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class TdnnLstm(nn.Module): - def __init__( - self, num_features: int, num_classes: int, subsampling_factor: int = 3 - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - It reduces the number of output frames by this factor. 
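The `subsampling_factor` argument above corresponds to the strided convolution at the end of the TDNN front end. A quick shape check with plain PyTorch (dimensions chosen to match the recipe's 80-dim features):

```
import torch
import torch.nn as nn

subsampling_factor = 3
conv = nn.Conv1d(
    in_channels=80, out_channels=512, kernel_size=3, stride=subsampling_factor
)
x = torch.randn(2, 80, 300)    # (N, C, T)
y = conv(x)
print(x.shape, "->", y.shape)  # torch.Size([2, 80, 300]) -> torch.Size([2, 512, 100])
```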
- """ - super().__init__() - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - nn.Conv1d( - in_channels=512, - out_channels=512, - kernel_size=3, - stride=self.subsampling_factor, # stride: subsampling_factor! - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=512, affine=False), - ) - self.lstms = nn.ModuleList( - [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)] - ) - self.lstm_bnorms = nn.ModuleList( - [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)] - ) - self.dropout = nn.Dropout(0.2) - self.linear = nn.Linear(in_features=512, out_features=self.num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Its shape is [N, C, T] - Returns: - The output tensor has shape [N, T, C] - """ - x = self.tdnn(x) - x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it - for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): - x_new, _ = lstm(x) - x_new = bnorm(x_new.permute(1, 2, 0)).permute( - 2, 0, 1 - ) # (T, N, C) -> (N, C, T) -> (T, N, C) - x_new = self.dropout(x_new) - x = x_new + x # skip connections - x = x.transpose( - 1, 0 - ) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim - x = self.linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py deleted file mode 100644 index f06c8c211..000000000 --- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import TdnnLstm -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. 
" - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.8, - help=""" - Used only when method is whole-lattice-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 3, - "num_classes": 41, - "sample_rate": 16000, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method == "whole-lattice-rescoring": - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - features = features.permute(0, 2, 1) # now features is (N, C, T) - - with torch.no_grad(): - nnet_output = model(features) - # nnet_output is (N, T, C) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] 
%(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py deleted file mode 100644 index be1ecffaa..000000000 --- a/egs/timit/ASR/tdnn_lstm_ctc/train.py +++ /dev/null @@ -1,601 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import TimitAsrDataModule -from lhotse.utils import fix_random_seed -from model import TdnnLstm -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. 
- - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-3, - "feature_dim": 80, - "weight_decay": 5e-4, - "subsampling_factor": 3, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. 
- model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of TdnnLstm in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
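`compute_loss()` above builds a decoding graph per batch and calls `k2.ctc_loss` on a `DenseFsaVec`. As a rough, framework-level analogue only (not the recipe's graph-based loss), the same quantities expressed with torch's built-in CTC loss look like this:

```
import torch
import torch.nn as nn

T, N, C = 100, 2, 41   # frames, batch size, classes (40 phones + blank)
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)       # (T, N, C)
targets = torch.randint(1, C, (N, 20), dtype=torch.long)   # padded label ids
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([18, 20], dtype=torch.long)

ctc = nn.CTCLoss(blank=0, reduction="sum")
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
```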
- """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=8, gamma=0.8) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) - - timit = TimitAsrDataModule(args) - train_dl = timit.train_dataloaders() - valid_dl = timit.valid_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if epoch > params.start_epoch: - logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") - - if tb_writer is not None: - tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - scheduler.step() - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - TimitAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/vctk/TTS/README.md b/egs/vctk/TTS/README.md deleted file mode 100644 index c2703dbe2..000000000 --- a/egs/vctk/TTS/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Introduction - -This CSTR VCTK Corpus includes speech data uttered by 110 English speakers with various accents. Each speaker reads out about 400 sentences, which were selected from a newspaper, the rainbow passage and an elicitation paragraph used for the speech accent archive. -The newspaper texts were taken from Herald Glasgow, with permission from Herald & Times Group. Each speaker has a different set of the newspaper texts selected based a greedy algorithm that increases the contextual and phonetic coverage. -The details of the text selection algorithms are described in the following paper: [C. Veaux, J. Yamagishi and S. 
King, "The voice bank corpus: Design, collection and data analysis of a large regional accent speech database,"](https://doi.org/10.1109/ICSDA.2013.6709856). - -The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk/handle/10283/3443). - -# VITS - -This recipe provides a VITS model trained on the VCTK dataset. - -Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset. - -For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html). - -The training command is given below: -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3" -./vits/train.py \ - --world-size 4 \ - --num-epochs 1000 \ - --start-epoch 1 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt - --max-duration 350 -``` - -To inference, use: -``` -./vits/infer.py \ - --epoch 1000 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt \ - --max-duration 500 -``` \ No newline at end of file diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py deleted file mode 100755 index 440ac1245..000000000 --- a/egs/vctk/TTS/local/compute_spectrogram_vctk.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the VCTK dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/spectrogram. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - LilcomChunkyWriter, - Spectrogram, - SpectrogramConfig, - load_manifest, -) -from lhotse.audio import RecordingSet -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_spectrogram_vctk(): - src_dir = Path("data/manifests") - output_dir = Path("data/spectrogram") - num_jobs = min(32, os.cpu_count()) - - sampling_rate = 22050 - frame_length = 1024 / sampling_rate # (in second) - frame_shift = 256 / sampling_rate # (in second) - use_fft_mag = True - - prefix = "vctk" - suffix = "jsonl.gz" - partition = "all" - - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet - ).resample(sampling_rate=sampling_rate) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet - ) - - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=frame_length, - frame_shift=frame_shift, - use_fft_mag=use_fft_mag, - ) - extractor = Spectrogram(config) - - with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_spectrogram_vctk() diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py deleted file mode 100755 index 0472e2cea..000000000 --- a/egs/vctk/TTS/local/display_manifest_statistics.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in vits/train.py -for usage. 
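For reference, the `SpectrogramConfig` above expresses the window and hop in seconds; at VCTK's 22,050 Hz resampling rate they correspond to a 1024-sample window and a 256-sample hop:

```
sampling_rate = 22050
frame_length = 1024 / sampling_rate   # FFT window, in seconds
frame_shift = 256 / sampling_rate     # hop, in seconds
print(f"window: {frame_length * 1000:.2f} ms, hop: {frame_shift * 1000:.2f} ms")
# window: 46.44 ms, hop: 11.61 ms
```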
-""" - - -from lhotse import load_manifest_lazy - - -def main(): - path = "./data/spectrogram/vctk_cuts_all.jsonl.gz" - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Cut statistics: -╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 43873 │ -├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 41:02:18 │ -├───────────────────────────┼──────────┤ -│ mean │ 3.4 │ -├───────────────────────────┼──────────┤ -│ std │ 1.2 │ -├───────────────────────────┼──────────┤ -│ min │ 1.2 │ -├───────────────────────────┼──────────┤ -│ 25% │ 2.6 │ -├───────────────────────────┼──────────┤ -│ 50% │ 3.1 │ -├───────────────────────────┼──────────┤ -│ 75% │ 3.8 │ -├───────────────────────────┼──────────┤ -│ 99% │ 8.0 │ -├───────────────────────────┼──────────┤ -│ 99.5% │ 9.1 │ -├───────────────────────────┼──────────┤ -│ 99.9% │ 12.1 │ -├───────────────────────────┼──────────┤ -│ max │ 16.6 │ -├───────────────────────────┼──────────┤ -│ Recordings available: │ 43873 │ -├───────────────────────────┼──────────┤ -│ Features available: │ 43873 │ -├───────────────────────────┼──────────┤ -│ Supervisions available: │ 43873 │ -╘═══════════════════════════╧══════════╛ -SUPERVISION custom fields: -Speech duration statistics: -╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 41:02:18 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │ -├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total silence duration │ 00:00:01 │ 0.00% of recording │ -╘══════════════════════════════╧══════════╧══════════════════════╛ -""" diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py deleted file mode 120000 index afc29a22b..000000000 --- a/egs/vctk/TTS/local/prepare_token_file.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/prepare_token_file.py \ No newline at end of file diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py deleted file mode 100755 index 0748eba5a..000000000 --- a/egs/vctk/TTS/local/prepare_tokens_vctk.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
-""" - -import logging -from pathlib import Path - -import tacotron_cleaner.cleaners -from lhotse import CutSet, load_manifest -from piper_phonemize import phonemize_espeak -from tqdm.auto import tqdm - - -def prepare_tokens_vctk(): - output_dir = Path("data/spectrogram") - prefix = "vctk" - suffix = "jsonl.gz" - partition = "all" - - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - new_cuts = [] - for cut in tqdm(cut_set): - # Each cut only contains one supervision - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) - text = cut.supervisions[0].text - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens_list = phonemize_espeak(text, "en-us") - tokens = [] - for t in tokens_list: - tokens.extend(t) - cut.tokens = tokens - new_cuts.append(cut) - - new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_tokens_vctk() diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py deleted file mode 100755 index cd466303e..000000000 --- a/egs/vctk/TTS/local/validate_manifest.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut - -We will add more checks later if needed. 
- -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/spectrogram/ljspeech_cuts_all.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset.speech_synthesis import validate_for_tts - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet) - - validate_for_tts(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh deleted file mode 100755 index aab075312..000000000 --- a/egs/vctk/TTS/prepare.sh +++ /dev/null @@ -1,140 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=100 -use_edinburgh_vctk_url=true - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib" - if [ ! -d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib already built" - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/VCTK, - # you can create a symlink - # - # ln -sfv /path/to/VCTK $dl_dir/VCTK - # - if [ ! -d $dl_dir/VCTK ]; then - lhotse download vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare VCTK manifest" - # We assume that you have downloaded the VCTK corpus - # to $dl_dir/VCTK - mkdir -p data/manifests - if [ ! -e data/manifests/.vctk.done ]; then - lhotse prepare vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir/VCTK data/manifests - touch data/manifests/.vctk.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute spectrogram for VCTK" - mkdir -p data/spectrogram - if [ ! -e data/spectrogram/.vctk.done ]; then - ./local/compute_spectrogram_vctk.py - touch data/spectrogram/.vctk.done - fi - - if [ ! -e data/spectrogram/.vctk-validated.done ]; then - log "Validating data/fbank for VCTK" - ./local/validate_manifest.py \ - data/spectrogram/vctk_cuts_all.jsonl.gz - touch data/spectrogram/.vctk-validated.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for VCTK" - # We assume you have installed piper_phonemize and espnet_tts_frontend. 
- # If not, please install them with: - # - piper_phonemize: - # refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: - # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/spectrogram/.vctk_with_token.done ]; then - ./local/prepare_tokens_vctk.py - mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ - data/spectrogram/vctk_cuts_all.jsonl.gz - touch data/spectrogram/.vctk_with_token.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Split the VCTK cuts into train, valid and test sets" - if [ ! -e data/spectrogram/.vctk_split.done ]; then - lhotse subset --last 600 \ - data/spectrogram/vctk_cuts_all.jsonl.gz \ - data/spectrogram/vctk_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/spectrogram/vctk_cuts_validtest.jsonl.gz \ - data/spectrogram/vctk_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/spectrogram/vctk_cuts_validtest.jsonl.gz \ - data/spectrogram/vctk_cuts_test.jsonl.gz - - rm data/spectrogram/vctk_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 )) - lhotse subset --first $n \ - data/spectrogram/vctk_cuts_all.jsonl.gz \ - data/spectrogram/vctk_cuts_train.jsonl.gz - touch data/spectrogram/.vctk_split.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Generate token file" - # We assume you have installed piper_phonemize and espnet_tts_frontend. - # If not, please install them with: - # - piper_phonemize: - # refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: - # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ - if [ ! -e data/tokens.txt ]; then - ./local/prepare_token_file.py --tokens data/tokens.txt - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate speakers file" - if [ ! -e data/speakers.txt ]; then - gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \ - | jq '.speaker' | sed 's/"//g' \ - | sort | uniq > data/speakers.txt - fi -fi diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/vctk/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py deleted file mode 120000 index 9972b476f..000000000 --- a/egs/vctk/TTS/vits/duration_predictor.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py deleted file mode 100755 index d00450f08..000000000 --- a/egs/vctk/TTS/vits/export-onnx.py +++ /dev/null @@ -1,295 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
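Stage 6 above builds `data/speakers.txt` with a jq/sed/sort pipeline; a Python equivalent over the same gzipped supervision manifest, for environments without jq:

```
import gzip
import json

speakers = set()
with gzip.open("data/manifests/vctk_supervisions_all.jsonl.gz", "rt") as f:
    for line in f:
        speakers.add(json.loads(line)["speaker"])

with open("data/speakers.txt", "w") as f:
    for speaker in sorted(speakers):   # matches `sort | uniq`
        print(speaker, file=f)
```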
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script exports a VITS model from PyTorch to ONNX. - -Export the model to ONNX: -./vits/export-onnx.py \ - --epoch 1000 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt - -It will generate two files inside vits/exp: - - vits-epoch-1000.onnx - - vits-epoch-1000.int8.onnx (quantizated model) - -See ./test_onnx.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import torch -import torch.nn as nn -from onnxruntime.quantization import QuantType, quantize_dynamic -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--speakers", - type=Path, - default=Path("data/speakers.txt"), - help="Path to speakers.txt file.", - ) - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class OnnxModel(nn.Module): - """A wrapper for VITS generator.""" - - def __init__(self, model: nn.Module): - """ - Args: - model: - A VITS generator. - frame_shift: - The frame shift in samples. - """ - super().__init__() - self.model = model - - def forward( - self, - tokens: torch.Tensor, - tokens_lens: torch.Tensor, - noise_scale: float = 0.667, - alpha: float = 1.0, - noise_scale_dur: float = 0.8, - speaker: int = 20, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of VITS.inference_batch - - Args: - tokens: - Input text token indexes (1, T_text) - tokens_lens: - Number of tokens of shape (1,) - noise_scale (float): - Noise scale parameter for flow. - noise_scale_dur (float): - Noise scale parameter for duration predictor. - speaker (int): - Speaker ID. - alpha (float): - Alpha parameter to control the speed of generated speech. 
- - Returns: - Return a tuple containing: - - audio, generated wavform tensor, (B, T_wav) - """ - audio, _, _ = self.model.inference( - text=tokens, - text_lengths=tokens_lens, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - sids=speaker, - alpha=alpha, - ) - return audio - - -def export_model_onnx( - model: nn.Module, - model_filename: str, - vocab_size: int, - n_speakers: int, - opset_version: int = 11, -) -> None: - """Export the given generator model to ONNX format. - The exported model has one input: - - - tokens, a tensor of shape (1, T_text); dtype is torch.int64 - - and it has one output: - - - audio, a tensor of shape (1, T'); dtype is torch.float32 - - Args: - model: - The VITS generator. - model_filename: - The filename to save the exported ONNX model. - vocab_size: - Number of tokens used in training. - opset_version: - The opset version to use. - """ - tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1], dtype=torch.float32) - noise_scale_dur = torch.tensor([1], dtype=torch.float32) - alpha = torch.tensor([1], dtype=torch.float32) - speaker = torch.tensor([1], dtype=torch.int64) - - torch.onnx.export( - model, - (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), - model_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "tokens", - "tokens_lens", - "noise_scale", - "alpha", - "noise_scale_dur", - "speaker", - ], - output_names=["audio"], - dynamic_axes={ - "tokens": {0: "N", 1: "T"}, - "tokens_lens": {0: "N"}, - "audio": {0: "N", 1: "T"}, - "speaker": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "vits", - "version": "1", - "model_author": "k2-fsa", - "comment": "icefall", # must be icefall for models from icefall - "language": "English", - "voice": "en-us", # Choose your language appropriately - "has_espeak": 1, - "n_speakers": n_speakers, - "sample_rate": 22050, # Must match the real sample rate - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=model_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - with open(args.speakers) as f: - speaker_map = {line.strip(): i for i, line in enumerate(f)} - params.num_spks = len(speaker_map) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model = model.generator - model.to("cpu") - model.eval() - - model = OnnxModel(model=model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}") - - suffix = f"epoch-{params.epoch}" - - opset_version = 13 - - logging.info("Exporting encoder") - model_filename = params.exp_dir / f"vits-{suffix}.onnx" - export_model_onnx( - model, - model_filename, - params.vocab_size, - params.num_spks, - opset_version=opset_version, - ) - logging.info(f"Exported generator to {model_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" - 
quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - weight_type=QuantType.QUInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py deleted file mode 120000 index e65d91ea7..000000000 --- a/egs/vctk/TTS/vits/flow.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py deleted file mode 120000 index 611679bfa..000000000 --- a/egs/vctk/TTS/vits/generator.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py deleted file mode 120000 index 5ac025de7..000000000 --- a/egs/vctk/TTS/vits/hifigan.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py deleted file mode 100755 index 2e1abdefb..000000000 --- a/egs/vctk/TTS/vits/infer.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script performs model inference on test set. - -Usage: -./vits/infer.py \ - --epoch 1000 \ - --exp-dir ./vits/exp \ - --max-duration 500 -""" - - -import argparse -import logging -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -import torch.nn as nn -import torchaudio -from tokenizer import Tokenizer -from train import get_model, get_params -from tts_datamodule import VctkTtsDataModule - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, setup_logger - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - return parser - - -def infer_dataset( - dl: torch.utils.data.DataLoader, - subset: str, - params: AttributeDict, - model: nn.Module, - tokenizer: Tokenizer, - speaker_map: Dict[str, int], -) -> None: - """Decode dataset. - The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. 
- tokenizer: - Used to convert text to phonemes. - """ - - # Background worker save audios to disk. - def _save_worker( - subset: str, - batch_size: int, - cut_ids: List[str], - audio: torch.Tensor, - audio_pred: torch.Tensor, - audio_lens: List[int], - audio_lens_pred: List[int], - ): - for i in range(batch_size): - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), - audio[i : i + 1, : audio_lens[i]], - sample_rate=params.sampling_rate, - ) - torchaudio.save( - str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), - audio_pred[i : i + 1, : audio_lens_pred[i]], - sample_rate=params.sampling_rate, - ) - - device = next(model.parameters()).device - num_cuts = 0 - log_interval = 5 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - futures = [] - with ThreadPoolExecutor(max_workers=1) as executor: - for batch_idx, batch in enumerate(dl): - batch_size = len(batch["tokens"]) - - tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - speakers = ( - torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]) - .int() - .to(device) - ) - - audio = batch["audio"] - audio_lens = batch["audio_lens"].tolist() - cut_ids = [cut.id for cut in batch["cut"]] - - audio_pred, _, durations = model.inference_batch( - text=tokens, - text_lengths=tokens_lens, - sids=speakers, - ) - audio_pred = audio_pred.detach().cpu() - # convert to samples - audio_lens_pred = ( - (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() - ) - - futures.append( - executor.submit( - _save_worker, - subset, - batch_size, - cut_ids, - audio, - audio_pred, - audio_lens, - audio_lens_pred, - ) - ) - - num_cuts += batch_size - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - # return results - for f in futures: - f.result() - - -@torch.no_grad() -def main(): - parser = get_parser() - VctkTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.suffix = f"epoch-{params.epoch}" - - params.res_dir = params.exp_dir / "infer" / params.suffix - params.save_wav_dir = params.res_dir / "wav" - params.save_wav_dir.mkdir(parents=True, exist_ok=True) - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Infer started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - # we need cut ids to display recognition results. 
- args.return_cuts = True - vctk = VctkTtsDataModule(args) - speaker_map = vctk.speakers() - params.num_spks = len(speaker_map) - - logging.info(f"Device: {device}") - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to(device) - model.eval() - - num_param_g = sum([p.numel() for p in model.generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - test_cuts = vctk.test_cuts() - test_dl = vctk.test_dataloaders(test_cuts) - - valid_cuts = vctk.valid_cuts() - valid_dl = vctk.valid_dataloaders(valid_cuts) - - infer_sets = {"test": test_dl, "valid": valid_dl} - - for subset, dl in infer_sets.items(): - save_wav_dir = params.res_dir / "wav" / subset - save_wav_dir.mkdir(parents=True, exist_ok=True) - - logging.info(f"Processing {subset} set, saving to {save_wav_dir}") - - infer_dataset( - dl=dl, - subset=subset, - params=params, - model=model, - tokenizer=tokenizer, - speaker_map=speaker_map, - ) - - logging.info(f"Wav files are saved to {params.save_wav_dir}") - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py deleted file mode 120000 index 672e5ff68..000000000 --- a/egs/vctk/TTS/vits/loss.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align deleted file mode 120000 index 71934e7cc..000000000 --- a/egs/vctk/TTS/vits/monotonic_align +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py deleted file mode 120000 index 41d64a3a6..000000000 --- a/egs/vctk/TTS/vits/posterior_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py deleted file mode 120000 index f979adbf0..000000000 --- a/egs/vctk/TTS/vits/residual_coupling.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py deleted file mode 100755 index ae6587338..000000000 --- a/egs/vctk/TTS/vits/test_onnx.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
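One detail in infer.py above that is easy to miss: the model returns per-token durations in feature frames, so the predicted waveform length is obtained by multiplying the summed durations by the hop size (params.frame_shift, 256 samples at 22050 Hz). A small sketch with made-up duration values, purely to illustrate the unit conversion:

    import torch

    frame_shift = 256            # hop size in samples (see get_params in train.py)
    # Pretend durations for a batch of 2 utterances, shape (B, T_text), in frames.
    durations = torch.tensor([[3.0, 5.0, 2.0], [4.0, 4.0, 0.0]])

    audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
    print(audio_lens_pred)       # [2560, 2048] samples, i.e. ~0.12 s and ~0.09 s at 22050 Hz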
- -""" -This script is used to test the exported onnx model by vits/export-onnx.py - -Use the onnx model to generate a wav: -./vits/test_onnx.py \ - --model-filename vits/exp/vits-epoch-1000.onnx \ - --tokens data/tokens.txt -""" - - -import argparse -import logging -from pathlib import Path - -import onnxruntime as ort -import torch -import torchaudio -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model.", - ) - - parser.add_argument( - "--speakers", - type=Path, - default=Path("data/speakers.txt"), - help="Path to speakers.txt file.", - ) - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - return parser - - -class OnnxModel: - def __init__(self, model_filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.model = ort.InferenceSession( - model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - - def __call__( - self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor - ) -> torch.Tensor: - """ - Args: - tokens: - A 1-D tensor of shape (1, T) - Returns: - A tensor of shape (1, T') - """ - noise_scale = torch.tensor([0.667], dtype=torch.float32) - noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) - alpha = torch.tensor([1.0], dtype=torch.float32) - - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: tokens.numpy(), - self.model.get_inputs()[1].name: tokens_lens.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: alpha.numpy(), - self.model.get_inputs()[4].name: noise_scale_dur.numpy(), - self.model.get_inputs()[5].name: speaker.numpy(), - }, - )[0] - return torch.from_numpy(out) - - -def main(): - args = get_parser().parse_args() - - tokenizer = Tokenizer(args.tokens) - - with open(args.speakers) as f: - speaker_map = {line.strip(): i for i, line in enumerate(f)} - args.num_spks = len(speaker_map) - - logging.info("About to create onnx model") - model = OnnxModel(args.model_filename) - - text = "I went there to see the land, the people and how their system works, end quote." 
- tokens = tokenizer.texts_to_token_ids( - [text], intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = torch.tensor(tokens) # (1, T) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) - speaker = torch.tensor([1], dtype=torch.int64) # (1, ) - audio = model(tokens, tokens_lens, speaker) # (1, T') - - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py deleted file mode 120000 index 0efba277e..000000000 --- a/egs/vctk/TTS/vits/text_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py deleted file mode 120000 index 057b0dc4b..000000000 --- a/egs/vctk/TTS/vits/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py deleted file mode 100755 index 4686de169..000000000 --- a/egs/vctk/TTS/vits/train.py +++ /dev/null @@ -1,1002 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
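The `intersperse_blank=True` argument used when converting text to token ids (here and in train.py's prepare_input) inserts a blank id between every pair of phoneme ids, as is common in VITS recipes. The actual logic lives in the shared ljspeech Tokenizer; the function below is only a generic illustration of the idea, not the project's implementation:

    # Generic illustration of "intersperse blank" (NOT the actual Tokenizer code).
    def intersperse(ids, blank_id=0):
        out = [blank_id]
        for i in ids:
            out.extend([i, blank_id])
        return out

    print(intersperse([7, 3, 9]))   # [0, 7, 0, 3, 0, 9, 0]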
- - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import VctkTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--lr", type=float, default=2.0e-4, help="The base learning rate." - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=20, - help="""Save checkpoint after processing this number of epochs" - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. 
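The default --lr of 2.0e-4 interacts with the ExponentialLR schedulers created later in run() (gamma=0.999875, stepped once per epoch). A rough back-of-the-envelope sketch of the implied decay over the default 1000 epochs, pure arithmetic only:

    # Approximate learning rate after stepping ExponentialLR once per epoch.
    base_lr = 2.0e-4             # default --lr
    gamma = 0.999875             # gamma used for both generator and discriminator
    for epoch in (100, 500, 1000):
        print(epoch, base_lr * gamma ** epoch)
    # ~1.98e-04 after 100 epochs, ~1.88e-04 after 500, ~1.77e-04 after 1000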
- - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - # training params - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 200, - "env_info": get_env_info(), - "sampling_rate": 22050, - "frame_shift": 256, - "frame_length": 1024, - "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "n_mels": 80, - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_mel": 45.0, # loss scaling coefficient for Mel loss - "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss - "lambda_dur": 1.0, # loss scaling coefficient for duration loss - "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
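The spectral settings in get_params() are mutually consistent with the SpectrogramConfig used in tts_datamodule.py; a small arithmetic sketch that double-checks them (nothing model-specific is assumed):

    # Consistency check for the feature settings in get_params().
    sampling_rate = 22050
    frame_length = 1024          # FFT size in samples
    frame_shift = 256            # hop size in samples

    feature_dim = frame_length // 2 + 1
    print(feature_dim)                           # 513, matches params.feature_dim
    print(frame_length / sampling_rate * 1000)   # ~46.4 ms analysis window
    print(frame_shift / sampling_rate * 1000)    # ~11.6 ms hop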
- - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = { - "n_mels": params.n_mels, - "frame_length": params.frame_length, - "frame_shift": params.frame_shift, - } - generator_params = { - "hidden_channels": 192, - "spks": params.num_spks, - "langs": None, - "spk_embed_dim": None, - "global_channels": 256, - "segment_size": 32, - "text_encoder_attention_heads": 2, - "text_encoder_ffn_expand": 4, - "text_encoder_cnn_module_kernel": 5, - "text_encoder_blocks": 6, - "text_encoder_dropout_rate": 0.1, - "decoder_kernel_size": 7, - "decoder_channels": 512, - "decoder_upsample_scales": [8, 8, 2, 2], - "decoder_upsample_kernel_sizes": [16, 16, 4, 4], - "decoder_resblock_kernel_sizes": [3, 7, 11], - "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "use_weight_norm_in_decoder": True, - "posterior_encoder_kernel_size": 5, - "posterior_encoder_layers": 16, - "posterior_encoder_stacks": 1, - "posterior_encoder_base_dilation": 1, - "posterior_encoder_dropout_rate": 0.0, - "use_weight_norm_in_posterior_encoder": True, - "flow_flows": 4, - "flow_kernel_size": 5, - "flow_base_dilation": 1, - "flow_layers": 4, - "flow_dropout_rate": 0.0, - "use_weight_norm_in_flow": True, - "use_only_mean_in_flow": True, - "stochastic_duration_predictor_kernel_size": 3, - "stochastic_duration_predictor_dropout_rate": 0.5, - "stochastic_duration_predictor_flows": 4, - "stochastic_duration_predictor_dds_conv_layers": 3, - } - model = VITS( - vocab_size=params.vocab_size, - feature_dim=params.feature_dim, - sampling_rate=params.sampling_rate, - generator_params=generator_params, - mel_loss_params=mel_loss_params, - lambda_adv=params.lambda_adv, - lambda_mel=params.lambda_mel, - lambda_feat_match=params.lambda_feat_match, - lambda_dur=params.lambda_dur, - lambda_kl=params.lambda_kl, - ) - return model - - -def prepare_input( - batch: dict, - tokenizer: Tokenizer, - device: torch.device, - speaker_map: Dict[str, int], -): - """Parse batch data""" - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - speakers = ( - torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) - ) - - tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True - ) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - speaker_map: Dict[str, int], - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to phonemes. - optimizer_g: - The optimizer for generator. - optimizer_d: - The optimizer for discriminator. - scheduler_g: - The learning rate scheduler for generator, we call step() every epoch. - scheduler_d: - The learning rate scheduler for discriminator, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=False, - ) - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(loss_d).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(loss_g).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/speech_", - speech_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - speaker_map=speaker_map, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - speaker_map: Dict[str, int], - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - returned_sample = None - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # infer for first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference( - text=tokens[0, : tokens_lens[0].item()], - sids=speakers[0], - ) - audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = ( - (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - ) - assert audio_len_pred == len(audio_pred), ( - audio_len_pred, - len(audio_pred), - ) - audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() - returned_sample = (audio_pred, audio_gt) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - tokenizer: Tokenizer, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - speaker_map: Dict[str, int], - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=False, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - sids=speakers, - forward_generator=True, - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - vctk = VctkTtsDataModule(args) - - train_cuts = vctk.train_cuts() - speaker_map = vctk.speakers() - params.num_spks = len(speaker_map) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - generator = model.generator - discriminator = model.discriminator - - num_param_g = sum([p.numel() for p in generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer_g = torch.optim.AdamW( - generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - optimizer_d = torch.optim.AdamW( - discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), 
eps=1e-9 - ) - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = vctk.train_dataloaders(train_cuts) - - valid_cuts = vctk.valid_cuts() - valid_dl = vctk.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - speaker_map=speaker_map, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - valid_dl=valid_dl, - speaker_map=speaker_map, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / 
"best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - # step per epoch - scheduler_g.step() - scheduler_d.step() - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - VctkTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py deleted file mode 120000 index 962647408..000000000 --- a/egs/vctk/TTS/vits/transform.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py deleted file mode 100644 index 6c785d8c3..000000000 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class VctkTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--speakers", - type=Path, - default=Path("data/speakers.txt"), - help="Path to speakers.txt file.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - 
return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") - - @lru_cache() - def speakers(self) -> Dict[str, int]: - logging.info("About to get speakers") - with open(self.args.speakers) as f: - speakers = {line.strip(): i for i, line in enumerate(f)} - return speakers diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py deleted file mode 120000 index 085e764b4..000000000 --- a/egs/vctk/TTS/vits/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py deleted file mode 120000 index 1f58cf6fe..000000000 --- a/egs/vctk/TTS/vits/vits.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py deleted file mode 120000 index 28f0a78ee..000000000 --- a/egs/vctk/TTS/vits/wavenet.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md deleted file mode 100644 index 92aa26464..000000000 --- a/egs/voxpopuli/ASR/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Readme - -This recipe contains data preparation for the -[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset -[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf). -At the moment, without model training. - - -## audio per language - -| language | Size | Hrs. untranscribed | Hrs. transcribed | -|----------|--------|--------------------|------------------| -| bg | 295G | 17.6K | - | -| cs | 308G | 18.7K | 62 | -| da | 233G | 13.6K | - | -| de | 379G | 23.2K | 282 | -| el | 305G | 17.7K | - | -| en | 382G | 24.1K | 543 | -| es | 362G | 21.4K | 166 | -| et | 179G | 10.6K | 3 | -| fi | 236G | 14.2K | 27 | -| fr | 376G | 22.8K | 211 | -| hr | 132G | 8.1K | 43 | -| hu | 297G | 17.7K | 63 | -| it | 361G | 21.9K | 91 | -| lt | 243G | 14.4K | 2 | -| lv | 217G | 13.1K | - | -| mt | 147G | 9.1K | - | -| nl | 322G | 19.0K | 53 | -| pl | 348G | 21.2K | 111 | -| pt | 300G | 17.5K | - | -| ro | 296G | 17.9K | 89 | -| sk | 201G | 12.1K | 35 | -| sl | 190G | 11.3K | 10 | -| sv | 272G | 16.3K | - | -| | | | | -| total | 6.3T | 384K | 1791 | - diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py deleted file mode 100755 index b63e51f29..000000000 --- a/egs/voxpopuli/ASR/local/compute_fbank.py +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of VoxPopuli dataset. - -Usage example: - - python3 ./local/compute_fbank.py \ - --src-dir data/fbank --output-dir data/fbank \ - --num-jobs 100 --num-workers 25 \ - --prefix "voxpopuli-${task}-${lang}" \ - --dataset train \ - --trim-to-supervisions True \ - --speed-perturb True - -It looks for raw CutSet in the directory data/fbank -located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`. - -The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats` -and CutSet manifest stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`. - -Typically, the number of workers is smaller than number of jobs -(see --num-jobs 100 --num-workers 25 in the example). -And, the number of jobs should be at least the number of workers (it's checked). -""" - -import argparse -import logging -import multiprocessing -import os -from concurrent.futures import ProcessPoolExecutor -from pathlib import Path - -import sentencepiece as spm -import torch -from filter_cuts import filter_cuts -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - is_caching_enabled, - set_caching_enabled, -) - -from icefall.utils import str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to the bpe.model. If not None, we will remove short and - long utterances before extracting features""", - ) - parser.add_argument( - "--src-dir", - type=str, - help="""Folder with the input manifest files.""", - default="data/manifests", - ) - parser.add_argument( - "--output-dir", - type=str, - help="""Folder with the output manifests (cuts) and feature files.""", - default="data/fbank", - ) - - parser.add_argument( - "--prefix", - type=str, - help="""Prefix of the manifest files.""", - default="", - ) - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank (train,test,dev).""", - default=None, - ) - - parser.add_argument( - "--num-jobs", - type=int, - help="""Number of jobs (i.e. 
files with extracted features)""", - default=50, - ) - parser.add_argument( - "--num-workers", - type=int, - help="""Number of parallel workers""", - default=10, - ) - parser.add_argument( - "--speed-perturb", - type=str2bool, - default=False, - help="""Enable speed perturbation for the set.""", - ) - parser.add_argument( - "--trim-to-supervisions", - type=str2bool, - default=False, - help="""Apply `trim-to-supervision` to cut set.""", - ) - - return parser.parse_args() - - -def compute_fbank_features(args: argparse.Namespace): - set_caching_enabled(True) # lhotse - - src_dir = Path(args.src_dir) - output_dir = Path(args.output_dir) - num_jobs = args.num_jobs - num_workers = min(args.num_workers, os.cpu_count()) - num_mel_bins = 80 - - bpe_model = args.bpe_model - if bpe_model: - logging.info(f"Loading {bpe_model}") - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - prefix = args.prefix # "ELEF_TRAIN" - dataset = args.dataset - suffix = "jsonl.gz" - - cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}") - cuts_raw = CutSet.from_file(cuts_raw_filename) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}") - if (output_dir / cuts_filename).is_file(): - logging.info(f"{output_dir/cuts_filename} already exists - skipping.") - return - - logging.info(f"Processing {output_dir/cuts_filename}") - cut_set = cuts_raw - - if bpe_model: - cut_set = filter_cuts(cut_set, sp) - - if args.speed_perturb: - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - - if args.trim_to_supervisions: - logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}") - cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - else: - logging.info( - "Not doing `trim_to_supervisions()`, " - "to enable use --trim-to-supervision=True" - ) - - cut_set = cut_set.to_eager() # disallow lazy evaluation (sorting requires it) - cut_set = cut_set.sort_by_recording_id() # enhances AudioCache hit rate - - # We typically use `num_jobs=100, num_workers=20` - # - this is helpful for large databases - # - both values are configurable externally - assert num_jobs >= num_workers, (num_jobs, num_workers) - executor = ProcessPoolExecutor( - max_workers=num_workers, - mp_context=multiprocessing.get_context("spawn"), - initializer=set_caching_enabled, - initargs=(is_caching_enabled(),), - ) - - logging.info( - f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}" - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir / prefix}-{dataset}_feats", - num_jobs=num_jobs, - executor=executor, - storage_type=LilcomChunkyWriter, - ) - - # correct small deviations of duration, caused by speed-perturbation - for cut in cut_set: - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id) - duration_difference = abs(cut.supervisions[0].duration - cut.duration) - tolerance = 0.02 # 20ms - if duration_difference == 0.0: - pass - elif duration_difference <= tolerance: - logging.info( - "small mismatch of the supervision duration " - f"(Δt = {duration_difference*1000}ms), " - f"correcting : cut.duration {cut.duration} -> " - f"supervision {cut.supervisions[0].duration}" - ) - cut.supervisions[0].duration = cut.duration - else: - logging.error( - "mismatch of cut/supervision duration " - f"(Δt = {duration_difference*1000}ms) : " - f"cut.duration {cut.duration}, " - f"supervision {cut.supervisions[0].duration}" - ) - raise ValueError( 
- "mismatch of cut/supervision duration " - f"(Δt = {duration_difference*1000}ms)" - ) - - # store the cutset - logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`") - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - logging.info(vars(args)) - - compute_fbank_features(args) diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/voxpopuli/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 36c99e126..000000000 --- a/egs/voxpopuli/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -Usage example: - python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. - -""" - -import argparse - -from lhotse import load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz") - - parser.add_argument( - "filename", - help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - cuts = load_manifest_lazy(args.filename) - cuts.describe() - - -if __name__ == "__main__": - main() diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py deleted file mode 100755 index 957267fe8..000000000 --- a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script computes durations of datasets from -the SupervisionSet manifests. - -Usage example: - - python3 ./local/duration_from_supervision_manifest.py \ - data/manifest/*_superivions*.jsonl.gz -""" - -import argparse -import gzip -import json -import logging -import re -import sys - - -def get_args(): - parser = argparse.ArgumentParser( - "Read the raw text from the 'supervisions.jsonl.gz'" - ) - - parser.add_argument( - "filename", - help="supervisions.jsonl.gz", - nargs="+", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.info(vars(args)) - - total_duration = 0.0 - total_n_utts = 0 - - for fname in args.filename: - if fname == "-": - fd = sys.stdin - elif re.match(r".*\.jsonl\.gz$", fname): - fd = gzip.open(fname, mode="r") - else: - fd = open(fname, mode="r") - - fname_duration = 0.0 - n_utts = 0 - for line in fd: - js = json.loads(line) - fname_duration += js["duration"] - n_utts += 1 - - print( - f"Duration: {fname_duration/3600:7.2f} hours " - f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}" - ) - - if fd != sys.stdin: - fd.close() - - total_duration += fname_duration - total_n_utts += n_utts - - print( - f"Total duration: {total_duration/3600:7.2f} hours " - f"(eq. {total_duration:7.0f} seconds)" - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py deleted file mode 120000 index 27aca1729..000000000 --- a/egs/voxpopuli/ASR/local/filter_cuts.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py deleted file mode 100755 index 4032537db..000000000 --- a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# 2023 Brno University of Technology (author: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Preprocess the database. 
-- Convert RecordingSet and SupervisionSet to CutSet. -- Apply text normalization to the transcripts. - - We take renormalized `orig_text` as `text` transcripts. - - The text normalization is separating punctuation from words. - - Also we put capital letter to the beginning of a sentence. - -The script is inspired in: - `egs/commonvoice/ASR/local/preprocess_commonvoice.py` - -Usage example: - python3 ./local/preprocess_voxpopuli.py \ - --task asr --lang en - -""" - -import argparse -import logging -from pathlib import Path -from typing import Optional - -from lhotse import CutSet -from lhotse.recipes.utils import read_manifests_if_cached - -# from local/ -from separate_punctuation import separate_punctuation -from uppercase_begin_of_sentence import UpperCaseBeginOfSentence - -from icefall.utils import str2bool - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - default=None, - ) - - parser.add_argument( - "--task", - type=str, - help="""Task of VoxPopuli""", - default="asr", - ) - - parser.add_argument( - "--lang", - type=str, - help="""Language of VoxPopuli""", - required=True, - ) - - parser.add_argument( - "--use-original-text", - type=str2bool, - help="""Use 'original_text' from the annoattaion file, - otherwise 'normed_text' will be used - (see `data/manifests/${task}_${lang}.tsv.gz`). - """, - default=False, - ) - - return parser.parse_args() - - -def normalize_text(utt: str) -> str: - utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt)) - return utt - - -def preprocess_voxpopuli( - task: str, - language: str, - dataset: Optional[str] = None, - use_original_text: bool = False, -): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - output_dir.mkdir(exist_ok=True) - - if dataset is None: - dataset_parts = ( - "dev", - "test", - "train", - ) - else: - dataset_parts = dataset.split(" ", -1) - - logging.info("Loading manifest") - prefix = f"voxpopuli-{task}-{language}" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix=suffix, - prefix=prefix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - if use_original_text: - logging.info("Using 'original_text' from the annotation file.") - logging.info(f"Normalizing text in {partition}") - for sup in m["supervisions"]: - # `orig_text` includes punctuation and true-case - orig_text = str(sup.custom["orig_text"]) - # we replace `text` by normalized `orig_text` - sup.text = normalize_text(orig_text) - else: - logging.info("Using 'normed_text' from the annotation file.") - - # remove supervisions with empty 'text' - m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0) - - # Create cut manifest with long-recordings. - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ).resample(16000) - - # Store the cut set incl. the resampling. 
- logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - preprocess_voxpopuli( - task=args.task, - language=args.lang, - dataset=args.dataset, - use_original_text=args.use_original_text, - ) - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py deleted file mode 100755 index 706d6fcd5..000000000 --- a/egs/voxpopuli/ASR/local/separate_punctuation.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script chops the punctuation as standalone tokens. -Example: - input: "This is fine. Yes, you are right." - output: "This is fine . Yes , you are right ." - -The script also handles exceptions in a hard-coded fashion. - -(same functionality could be done with `nltk.tokenize.word_tokenize()`, - but that would be an extra dependency) - -It can be used as a module, or as an executable script. - -Usage example #1: - `from separate_punctuation import separate_punctuation` - -Usage example #2: -``` - python3 ./local/separate_punctuation.py \ - --ignore-columns 1 \ - < ${kaldi_data}/text -``` -""" - -import re -import sys -from argparse import ArgumentParser - - -def separate_punctuation(text: str) -> str: - """ - Text filtering function for separating punctuation. - - Example: - input: "This is fine. Yes, you are right." - output: "This is fine . Yes , you are right ." - - The exceptions for which the punctuation is - not splitted are hard-coded. - """ - - # remove non-desired punctuation symbols - text = re.sub('["„“«»]', "", text) - - # separate [,.!?;] punctuation from words by space - text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text) - text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text) - - # split to tokens - tokens = text.split() - tokens_out = [] - - # re-join the special cases of punctuation - for ii, tok in enumerate(tokens): - # no rewriting for 1st and last token - if ii > 0 and ii < len(tokens) - 1: - # **RULES ADDED FOR CZECH COMMON VOICE** - - # fix "27 . dubna" -> "27. dubna", but keep punctuation separate, - if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower(): - tokens_out[-1] = tokens_out[-1] + "." - continue - - # fix "resp . pak" -> "resp. pak" - if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower(): - tokens_out[-1] = tokens_out[-1] + "." - continue - - # **RULES ADDED FOR ENGLISH COMMON VOICE** - - # fix "A ." -> "A." - if tok == "." and re.match(r"^[A-Z]S", tokens[ii - 1]): - tokens_out[-1] = tokens_out[-1] + "." - continue - - # fix "Mr ." -> "Mr." 
- exceptions = set(["Mr", "Mrs", "Ms"]) - if tok == "." and tokens[ii - 1] in exceptions: - tokens_out[-1] = tokens_out[-1] + "." - continue - - tokens_out.append(tok) - - return " ".join(tokens_out) - - -def get_args(): - parser = ArgumentParser( - description="Separate punctuation from words: 'hello.' -> 'hello .'" - ) - parser.add_argument( - "--ignore-columns", type=int, default=1, help="skip number of initial columns" - ) - return parser.parse_args() - - -def main(): - args = get_args() - - max_split = args.ignore_columns - - while True: - line = sys.stdin.readline() - if not line: - break - - *key, text = line.strip().split(maxsplit=max_split) - text_norm = separate_punctuation(text) - - print(" ".join(key), text_norm) - - -if __name__ == "__main__": - main() diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py deleted file mode 100755 index d9ab53b5a..000000000 --- a/egs/voxpopuli/ASR/local/text_from_manifest.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`. - -Usage example: - python3 ./local/text_from_manifest.py \ - data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz -""" - -import argparse -import gzip -import json - - -def get_args(): - parser = argparse.ArgumentParser( - "Read the raw text from the 'supervisions.jsonl.gz'" - ) - parser.add_argument("filename", help="supervisions.jsonl.gz") - return parser.parse_args() - - -def main(): - args = get_args() - - with gzip.open(args.filename, mode="r") as fd: - for line in fd: - js = json.loads(line) - if "text" in js: - print(js["text"]) # supervisions.jsonl.gz - elif "supervisions" in js: - for s in js["supervisions"]: - print(s["text"]) # cuts.jsonl.gz - else: - raise Exception(f"Unknown jsonl format of {args.filename}") - - -if __name__ == "__main__": - main() diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/voxpopuli/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py deleted file mode 100755 index 8e9de905f..000000000 --- a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script introduces initial capital letter at the beginning of a sentence. -It can be used as a module, or as an executable script. - -Usage example #1: - `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence` - -Usage example #2: -``` - python3 ./local/uppercase_begin_of_sentence.py \ - --ignore-columns 1 \ - < ${kaldi_data}/text -``` -""" - -import re -import sys -from argparse import ArgumentParser - - -class UpperCaseBeginOfSentence: - """ - This class introduces initial capital letter at the beginning of a sentence. - Capital letter is used, if previous symbol was punctuation token from - `set([".", "!", "?"])`. - - The punctuation as previous token is memorized also across - `process_line_text()` calls. - """ - - def __init__(self): - # The 1st word will have Title-case - # This variable transfers context from previous line - self.prev_token_is_punct = True - - def process_line_text(self, line_text: str) -> str: - """ - It is assumed that punctuation in `line_text` was already separated, - example: "This is fine . Yes , you are right ." - """ - - words = line_text.split() - punct_set = set([".", "!", "?"]) - - for ii, w in enumerate(words): - # punctuation ? - if w in punct_set: - self.prev_token_is_punct = True - continue - - # change case of word... - if self.prev_token_is_punct: - if re.match("<", w): - continue # skip - # apply Title-case only on lowercase words. - if w.islower(): - words[ii] = w.title() - # change state - self.prev_token_is_punct = False - - line_text_uc = " ".join(words) - - return line_text_uc - - -def get_args(): - parser = ArgumentParser( - description="Put upper-case at the beginning of a sentence." - ) - parser.add_argument( - "--ignore-columns", type=int, default=4, help="skip number of initial columns" - ) - return parser.parse_args() - - -def main(): - args = get_args() - - uc_bos = UpperCaseBeginOfSentence() - max_split = args.ignore_columns - - while True: - line = sys.stdin.readline() - if not line: - break - line = line.strip() - - if len(line.split()) > 1: - *key, text = line.strip().split(maxsplit=max_split) # parse, - text_uc = uc_bos.process_line_text(text) # process, - print(" ".join(key), text_uc) # print, - else: - print(line) - - -if __name__ == "__main__": - main() diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py deleted file mode 100755 index 4659aa9cd..000000000 --- a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) -# 2023 Brno University of Technology (authors: Karel Veselý) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within Cut time bounds -- Duration of Cut and Superivion are equal - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz - -(Based on: `librispeech/ASR/local/validate_manifest.py`) -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut -from lhotse.dataset.speech_recognition import validate_for_asr - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "cutset_manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - tol = 2e-3 # same tolerance as in 'validate_for_asr()' - s = c.supervisions[0] - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - if s.start < -tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} must not be negative." - ) - if s.start > tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} " - "is not at the beginning of the Cut. " - "Please apply `lhotse cut trim-to-supervisions`." - ) - if c.start + s.end > c.end + tol: - raise ValueError( - f"{c.id}: Supervision end time {c.start+s.end} is larger " - f"than cut end time {c.end}" - ) - - if s.duration != c.duration: - raise ValueError( - f"{c.id}: Cut duration {c.duration} and supervision duration " - f"{s.duration} must be the same.\n" - f"The difference causes problems in the training code : " - f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n" - f"Did you forget to apply `trim_to_supervisions()` ?" 
- ) - - -def main(): - args = get_args() - - manifest = args.cutset_manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet) - - try: - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - # Validation from K2 training - # - checks supervision start is 0 - # - checks supervision.duration is not longer than cut.duration - # - there is tolerance 2ms - validate_for_asr(cut_set) - except BaseException as e: - logging.error(str(e)) - raise - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh deleted file mode 100755 index 7cddad756..000000000 --- a/egs/voxpopuli/ASR/prepare.sh +++ /dev/null @@ -1,257 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -euxo pipefail - -nj=20 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/voxpopuli/raw_audios/$lang/$year -# This directory contains *.ogg files with audio downloaded and extracted from archives: -# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar -# -# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder -# as part of `lhotse prepare voxpopuli` from: -# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download -#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT - -musan_dir=${dl_dir}/musan -#musan_dir=/mnt/matylda2/data/MUSAN # BUT - -# Choose value from ASR_LANGUAGES: -# -# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", -# "sk", "sl", "et", "lt" ] -# -# See ASR_LANGUAGES in: -# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4 -lang=en - -task=asr - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/${lang}/lang_bpe_xxx, -# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data/${lang}". -# You can safely remove "data/${lang}" and rerun this script to regenerate it. -mkdir -p data/${lang} - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" -log "musan_dir: $musan_dir" -log "task: $task, lang: $lang" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/$release, - # you can create a symlink - # - # ln -sfv /path/to/$release $dl_dir/$release - # - if [ ! 
-d $dl_dir/voxpopuli/raw_audios/${lang} ]; then - lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $musan_dir/musan ]; then - lhotse download musan $musan_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare VoxPopuli manifest" - # We assume that you have downloaded the VoxPopuli corpus - # to $dl_dir/voxpopuli - if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then - # Warning : it requires Internet connection (it downloads transcripts to ${tmpdir}) - lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests - touch data/manifests/.voxpopuli-${task}-${lang}.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - #lhotse prepare musan $dl_dir/musan data/manifests - lhotse prepare musan $musan_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Preprocess VoxPopuli manifest" - mkdir -p data/fbank - if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then - # recordings + supervisions -> cutset - ./local/preprocess_voxpopuli.py --task $task --lang $lang \ - --use-original-text True - touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli" - mkdir -p data/fbank - for dataset in "dev" "test"; do - if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then - ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \ - --num-jobs 50 --num-workers ${nj} \ - --prefix "voxpopuli-${task}-${lang}" \ - --dataset ${dataset} \ - --trim-to-supervisions True - touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done - fi - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compute fbank for train set of VoxPopuli" - if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then - ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \ - --num-jobs 100 --num-workers ${nj} \ - --prefix "voxpopuli-${task}-${lang}" \ - --dataset train \ - --trim-to-supervisions True \ - --speed-perturb True - touch data/fbank/.voxpopuli-${task}-${lang}-train.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Validate fbank manifests for VoxPopuli" - for dataset in "dev" "test" "train"; do - mkdir -p data/fbank/log/ - ./local/validate_cutset_manifest.py \ - data/fbank/voxpopuli-asr-en_cuts_${dataset}.jsonl.gz \ - 2>&1 | tee data/fbank/log/validate_voxpopuli-asr-en_cuts_${dataset}.log - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size}_${lang} - mkdir -p $lang_dir - - if [ ! 
-f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - file=$( - find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz" - ) - local/text_from_manifest.py $file >$lang_dir/transcript_words.txt - # gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt - - # Ensure space only appears once - #sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - #sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/words.txt ]; then - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo ''; echo ''; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/voxpopuli/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech/ASR/README.md b/egs/wenetspeech/ASR/README.md deleted file mode 100644 index 44e631b4a..000000000 --- a/egs/wenetspeech/ASR/README.md +++ /dev/null @@ -1,20 +0,0 @@ - -# Introduction - -This recipe includes some different ASR models trained with WenetSpeech. - -[./RESULTS.md](./RESULTS.md) contains the latest results. - -# Transducers - -There are various folders containing the name `transducer` in this folder. -The following table lists the differences among them. - -| | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|-----------------------------| -| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | -| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | - -The decoder in `transducer_stateless` is modified from the paper -[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). -We place an additional Conv1d layer right after the input embedding layer. 
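To make the stateless-decoder description in the README above concrete, here is a minimal sketch, assuming an embedding over the previous `context_size` labels followed by a Conv1d; the class and parameter names are illustrative only and do not reproduce the recipe's actual `decoder.py`:

```python
# Hypothetical sketch of the "stateless" prediction network described above:
# an embedding of the last `context_size` emitted labels followed by a Conv1d,
# with no recurrent state. Names and sizes are illustrative, not the recipe's.
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A Conv1d over the last `context_size` labels plays the role of an RNN.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, context_size) previously emitted token ids
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, embed_dim, context)
        out = self.conv(emb)                      # (batch, embed_dim, 1)
        return out.permute(0, 2, 1)               # (batch, 1, embed_dim)


if __name__ == "__main__":
    dec = StatelessDecoder(vocab_size=500, embed_dim=512, context_size=2)
    print(dec(torch.tensor([[3, 7]])).shape)  # torch.Size([1, 1, 512])
```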
diff --git a/egs/wenetspeech/ASR/RESULTS.md b/egs/wenetspeech/ASR/RESULTS.md deleted file mode 100644 index 1a0e0681f..000000000 --- a/egs/wenetspeech/ASR/RESULTS.md +++ /dev/null @@ -1,294 +0,0 @@ -## Results - -### WenetSpeech char-based training results (Non-streaming and streaming) on zipformer model - -This is the [pull request](https://github.com/k2-fsa/icefall/pull/1130) in icefall. - -#### Non-streaming - -Best results (num of params : ~76M): - -Type | Greedy(dev & net & meeting) | Beam search(dev & net & meeting) |   --- | -- | -- | -- -Non-streaming | 7.36 & 7.65 & 12.43 | 7.32 & 7.61 & 12.35 | --epoch=12 - -The training command: - -``` -./zipformer/train.py \ - --world-size 6 \ - --num-epochs 12 \ - --use-fp16 1 \ - --max-duration 450 \ - --training-subset L \ - --lr-epochs 1.5 \ - --context-size 2 \ - --exp-dir zipformer/exp_L_context_2 \ - --causal 0 \ - --num-workers 8 -``` - -Listed best results for each epoch below: - -Epoch | Greedy search(dev & net & meeting) | Modified beam search(dev & net & meeting) |   --- | -- | -- | -- -4 | 7.83 & 8.86 &13.73 | 7.75 & 8.81 & 13.67 | avg=1;blank-penalty=2 -5 | 7.75 & 8.46 & 13.38 | 7.68 & 8.41 & 13.27 | avg=1;blank-penalty=2 -6 | 7.72 & 8.19 & 13.16 | 7.62 & 8.14 & 13.06 | avg=1;blank-penalty=2 -7 | 7.59 & 8.08 & 12.97 | 7.53 & 8.01 & 12.87 | avg=2;blank-penalty=2 -8 | 7.68 & 7.87 & 12.96 | 7.61 & 7.81 & 12.88 | avg=1;blank-penalty=2 -9 | 7.57 & 7.77 & 12.87 | 7.5 & 7.71 & 12.77 | avg=1;blank-penalty=2 -10 | 7.45 & 7.7 & 12.69 | 7.39 & 7.63 & 12.59 | avg=2;blank-penalty=2 -11 | 7.35 & 7.67 & 12.46 | 7.31 & 7.63 & 12.43 | avg=3;blank-penalty=2 -12 | 7.36 & 7.65 & 12.43 | 7.32 & 7.61 & 12.35 | avg=4;blank-penalty=2 - -The pre-trained model is available here : https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 - - -#### Streaming - -Best results (num of params : ~76M): - -Type | Greedy(dev & net & meeting) | Beam search(dev & net & meeting) |   --- | -- | -- | -- -Streaming | 8.45 & 9.89 & 16.46 | 8.21 & 9.77 & 16.07 | --epoch=12; --chunk-size=16; --left-context-frames=256 -Streaming | 8.0 & 9.0 & 15.11 | 7.84 & 8.94 & 14.92 | --epoch=12; --chunk-size=32; --left-context-frames=256 - -The training command: - -``` -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --use-fp16 1 \ - --max-duration 450 \ - --training-subset L \ - --lr-epochs 1.5 \ - --context-size 2 \ - --exp-dir zipformer/exp_L_causal_context_2 \ - --causal 1 \ - --num-workers 8 -``` - -Best results for each epoch (--chunk-size=16; --left-context-frames=128) - -Epoch | Greedy search(dev & net & meeting) | Modified beam search(dev & net & meeting) |   --- | -- | -- | -- -6 | 9.14 & 10.75 & 18.15 | 8.79 & 10.54 & 17.64 | avg=1;blank-penalty=1.5 -7 | 9.11 & 10.61 & 17.86 | 8.8 & 10.42 & 17.29 | avg=1;blank-penalty=1.5 -8 | 8.89 & 10.32 & 17.44 | 8.59 & 10.09 & 16.9 | avg=1;blank-penalty=1.5 -9 | 8.86 & 10.11 & 17.35 | 8.55 & 9.87 & 16.76 | avg=1;blank-penalty=1.5 -10 | 8.66 & 10.0 & 16.94 | 8.39 & 9.83 & 16.47 | avg=2;blank-penalty=1.5 -11 | 8.58 & 9.92 & 16.67 | 8.32 & 9.77 & 16.27 | avg=3;blank-penalty=1.5 -12 | 8.45 & 9.89 & 16.46 | 8.21 & 9.77 & 16.07 | avg=4;blank-penalty=1.5 - -The pre-trained model is available here: https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 - - -### WenetSpeech char-based training results (offline and streaming) (Pruned Transducer 5) - -#### 2022-07-22 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/447. 
- -When training with the L subset, the CERs are - -**Offline**: -|decoding-method| epoch | avg | use-averaged-model | DEV | TEST-NET | TEST-MEETING| -|-- | -- | -- | -- | -- | -- | --| -|greedy_search | 4 | 1 | True | 8.22 | 9.03 | 14.54| -|modified_beam_search | 4 | 1 | True | **8.17** | **9.04** | **14.44**| -|fast_beam_search | 4 | 1 | True | 8.29 | 9.00 | 14.93| - -The offline training command for reproducing is given below: -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_offline \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 2 \ - --max-duration 120 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --average-period 1000 \ - --training-subset L -``` - -The tensorboard training log can be found at https://tensorboard.dev/experiment/SvnN2jfyTB2Hjqu22Z7ZoQ/#scalars . - - -A pre-trained offline model and decoding logs can be found at - -**Streaming**: -|decoding-method| epoch | avg | use-averaged-model | DEV | TEST-NET | TEST-MEETING| -|--|--|--|--|--|--|--| -| greedy_search | 7| 1| True | 8.78 | 10.12 | 16.16 | -| modified_beam_search | 7| 1| True| **8.53**| **9.95** | **15.81** | -| fast_beam_search | 7 | 1| True | 9.01 | 10.47 | 16.28 | - -The streaming training command for reproducing is given below: -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_streaming \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 1 \ - --max-duration 140 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --average-period 1000 \ - --training-subset L \ - --dynamic-chunk-training True \ - --causal-convolution True \ - --short-chunk-size 25 \ - --num-left-chunks 4 -``` - -The tensorboard training log can be found at https://tensorboard.dev/experiment/E2NXPVflSOKWepzJ1a1uDQ/#scalars . - - -A pre-trained offline model and decoding logs can be found at - -### WenetSpeech char-based training results (Pruned Transducer 2) - -#### 2022-05-19 - -Using the codes from this PR https://github.com/k2-fsa/icefall/pull/349. 
- -When training with the L subset, the CERs are - -| | dev | test-net | test-meeting | comment | -|------------------------------------|-------|----------|--------------|------------------------------------------| -| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 | -| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 | -| fast beam search (1best) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 | -| fast beam search (nbest) | 9.82 | 10.98 | 16.37 | --epoch 10, --avg 2, --max-duration 600 | -| fast beam search (nbest oracle) | 6.88 | 7.18 | 11.77 | --epoch 10, --avg 2, --max-duration 600 | -| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 | - -The training command for reproducing is given below: - -``` -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --training-subset L -``` - -The tensorboard training log can be found at -https://tensorboard.dev/experiment/wM4ZUNtASRavJx79EOYYcg/#scalars - -The decoding command is: -``` -epoch=10 -avg=2 - -## greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -## modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -## fast beam search (1best) -./pruned_transducer_stateless2/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -## fast beam search (nbest) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -## fast beam search (nbest oracle WER) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -## fast beam search (with LG) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --ngram-lm-scale 0.35 \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -``` - -When training with the M subset, the CERs are - -| | dev | test-net | test-meeting | comment | -|------------------------------------|--------|-----------|---------------|-------------------------------------------| -| greedy search | 10.40 | 11.31 | 19.64 | --epoch 29, --avg 11, --max-duration 100 | -| 
modified beam search (beam size 4) | 9.85 | 11.04 | 18.20 | --epoch 29, --avg 11, --max-duration 100 | -| fast beam search (set as default) | 10.18 | 11.10 | 19.32 | --epoch 29, --avg 11, --max-duration 1500 | - - -When training with the S subset, the CERs are - -| | dev | test-net | test-meeting | comment | -|------------------------------------|--------|-----------|---------------|-------------------------------------------| -| greedy search | 19.92 | 25.20 | 35.35 | --epoch 29, --avg 24, --max-duration 100 | -| modified beam search (beam size 4) | 18.62 | 23.88 | 33.80 | --epoch 29, --avg 24, --max-duration 100 | -| fast beam search (set as default) | 19.31 | 24.41 | 34.87 | --epoch 29, --avg 24, --max-duration 1500 | - - -A pre-trained model and decoding logs can be found at diff --git a/egs/wenetspeech/ASR/finetune.sh b/egs/wenetspeech/ASR/finetune.sh deleted file mode 100755 index 8559780e9..000000000 --- a/egs/wenetspeech/ASR/finetune.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -# This is an example script for fine-tuning. Here, we fine-tune a model trained -# on WenetSpeech on Aishell. The model used for fine-tuning is -# pruned_transducer_stateless2 (zipformer). If you want to fine-tune model -# from another recipe, you can adapt ./pruned_transducer_stateless2/finetune.py -# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues. - -# We assume that you have already prepared the Aishell manfiest&features under ./data. -# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/prepare.sh. - -. 
shared/parse_options.sh || exit 1 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download Pre-trained model" - - # clone from huggingface - git lfs install - git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 - -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Start fine-tuning" - - # The following configuration of lr schedule should work well - # You may also tune the following parameters to adjust learning rate schedule - initial_lr=0.0001 - lr_epochs=100 - lr_batches=100000 - - # We recommend to start from an averaged model - finetune_ckpt=icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt - lang_dir=icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char - export CUDA_VISIBLE_DEVICES="0,1" - - ./pruned_transducer_stateless2/finetune.py \ - --world-size 2 \ - --master-port 18180 \ - --num-epochs 15 \ - --context-size 2 \ - --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ - --initial-lr $initial_lr \ - --lr-epochs $lr_epochs \ - --lr-batches $lr_batches \ - --lang-dir $lang_dir \ - --do-finetune True \ - --finetune-ckpt $finetune_ckpt \ - --max-duration 200 -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Decoding" - - epoch=4 - avg=4 - - for m in greedy_search modified_beam_search; do - python pruned_transducer_stateless2/decode_aishell.py \ - --epoch $epoch \ - --avg $avg \ - --context-size 2 \ - --beam-size 4 \ - --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ - --max-duration 400 \ - --decoding-method $m - done -fi diff --git a/egs/wenetspeech/ASR/local/compile_lg.py b/egs/wenetspeech/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/wenetspeech/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/compute_fbank_musan.py b/egs/wenetspeech/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/wenetspeech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py deleted file mode 100755 index ac4e92ec5..000000000 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, -) - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - return parser - - -def compute_fbank_wenetspeech_dev_test(args): - in_out_dir = Path("data/fbank") - # number of workers in dataloader - num_workers = 42 - - # number of seconds in a batch - batch_duration = 600 - - subsets = ("DEV", "TEST_NET", "TEST_MEETING") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - if args.whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") - ) - else: - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - - logging.info(f"device: {device}") - - for partition in subsets: - cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{in_out_dir}/feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - compute_fbank_wenetspeech_dev_test(args) - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py deleted file mode 100755 index 804a302bd..000000000 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ /dev/null @@ -1,211 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from datetime import datetime -from pathlib import Path - -import torch -from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig, - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - WhisperFbank, - WhisperFbankConfig, - set_audio_duration_mismatch_tolerance, - set_caching_enabled, -) - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--training-subset", - type=str, - default="L", - help="The training subset for computing fbank feature.", - ) - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--num-splits", - type=int, - required=True, - help="The number of splits of the L subset", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="Process pieces starting from this number (included).", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop processing pieces until this number (excluded).", - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - - parser.add_argument( - "--output-dir-prefix", - type=str, - default="", - help="Prefix of the output directory.", - ) - return parser - - -def compute_fbank_wenetspeech_splits(args): - subset = args.training_subset - subset = str(subset) - num_splits = args.num_splits - output_dir = f"data/fbank/{subset}_split_{num_splits}" - output_dir = Path(output_dir) - output_dir = Path(args.output_dir_prefix) / output_dir - assert output_dir.exists(), f"{output_dir} does not exist!" 
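The --start, --stop and --num-splits options above decide which pieces of the pre-split raw manifest a given job processes; the split files carry zero-padded indices. A small hypothetical helper (not part of the recipe) that enumerates the paths a job would touch under those rules:

```python
from pathlib import Path

def selected_split_paths(output_dir, subset, num_splits, start=0, stop=-1):
    # Mirrors the selection logic of compute_fbank_wenetspeech_splits.py:
    # stop < start (e.g. the default -1) means "process to the last split".
    if stop < start:
        stop = num_splits
    stop = min(stop, num_splits)
    num_digits = len(str(num_splits))
    return [
        Path(output_dir) / f"cuts_{subset}_raw.{str(i).zfill(num_digits)}.jsonl.gz"
        for i in range(start, stop)
    ]

# e.g. selected_split_paths("data/fbank/L_split_1000", "L", 1000, start=0, stop=100)
```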
- - num_digits = len(str(num_splits)) - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - if args.whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) - ) - # extractor = KaldifeatWhisperFbank(KaldifeatWhisperFbankConfig(num_filters=args.num_mel_bins, device=device)) - else: - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance - set_caching_enabled(False) - # with get_executor() as ex: # Initialize the executor only once. - for i in range(start, stop): - idx = f"{i}".zfill(num_digits) - logging.info(f"Processing {i+1}/{num_splits}") - - cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = output_dir / f"cuts_{subset}_raw.{idx}.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/feats_{subset}_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - now = datetime.now() - date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - - log_filename = "log-compute_fbank_wenetspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - log_filename = f"{log_filename}-{date_time}" - - logging.basicConfig( - filename=log_filename, - format=formatter, - level=logging.INFO, - filemode="w", - ) - - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter(formatter)) - logging.getLogger("").addHandler(console) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - compute_fbank_wenetspeech_splits(args) - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py deleted file mode 100644 index 36e4ac5c3..000000000 --- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. 
-You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -See the function `remove_short_and_long_utt()` -in ../../../librispeech/ASR/transducer/train.py -for usage. -""" - - -from lhotse import load_manifest_lazy - - -def main(): - paths = [ - "./data/fbank/cuts_S.jsonl.gz", - "./data/fbank/cuts_M.jsonl.gz", - "./data/fbank/cuts_L.jsonl.gz", - "./data/fbank/cuts_DEV.jsonl.gz", - "./data/fbank/cuts_TEST_NET.jsonl.gz", - "./data/fbank/cuts_TEST_MEETING.jsonl.gz", - ] - - for path in paths: - print(f"Starting display the statistics for {path}") - cuts = load_manifest_lazy(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -Starting display the statistics for ./data/fbank/cuts_L.jsonl.gz - -Cuts count: 43874235 -Total duration (hours): 30217.3 -Speech duration (hours): 30217.3 (100.0%) -*** -Duration statistics (seconds): -mean 2.5 -std 1.7 -min 0.2 -25% 1.4 -50% 2.0 -75% 3.0 -99% 8.4 -99.5% 9.1 -99.9% 15.4 -max 405.1 - -Starting display the statistics for ./data/fbank/cuts_S.jsonl.gz -Duration statistics (seconds): -mean 2.4 -std 1.8 -min 0.2 -25% 1.4 -50% 2.0 -75% 2.9 -99% 8.0 -99.5% 8.7 -99.9% 11.9 -max 405.1 - -Starting display the statistics for ./data/fbank/cuts_M.jsonl.gz -Cuts count: 4543341 -Total duration (hours): 3021.1 -Speech duration (hours): 3021.1 (100.0%) -*** -Duration statistics (seconds): -mean 2.4 -std 1.6 -min 0.2 -25% 1.4 -50% 2.0 -75% 2.9 -99% 8.0 -99.5% 8.8 -99.9% 12.1 -max 405.1 - -Starting display the statistics for ./data/fbank/cuts_DEV.jsonl.gz -Cuts count: 13825 -Total duration (hours): 20.0 -Speech duration (hours): 20.0 (100.0%) -*** -Duration statistics (seconds): -mean 5.2 -std 2.2 -min 1.0 -25% 3.3 -50% 4.9 -75% 7.0 -99% 9.6 -99.5% 9.8 -99.9% 10.0 -max 10.0 - -Starting display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz -Cuts count: 24774 -Total duration (hours): 23.1 -Speech duration (hours): 23.1 (100.0%) -*** -Duration statistics (seconds): -mean 3.4 -std 2.6 -min 0.1 -25% 1.4 -50% 2.4 -75% 4.8 -99% 13.1 -99.5% 14.5 -99.9% 18.5 -max 33.3 - -Starting display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz -Cuts count: 8370 -Total duration (hours): 15.2 -Speech duration (hours): 15.2 (100.0%) -*** -Duration statistics (seconds): -mean 6.5 -std 3.5 -min 0.8 -25% 3.7 -50% 5.8 -75% 8.8 -99% 15.2 -99.5% 16.0 -99.9% 18.8 -max 24.6 - -""" diff --git a/egs/wenetspeech/ASR/local/fix_manifest.py b/egs/wenetspeech/ASR/local/fix_manifest.py deleted file mode 100644 index b2632bd52..000000000 --- a/egs/wenetspeech/ASR/local/fix_manifest.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 author: Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
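The statistics printed by display_manifest_statistics.py are meant to guide the duration cutoffs used by remove_short_and_long_utt() in the training scripts. A minimal sketch of such a filter with lhotse; the 1.0 s and 15.0 s bounds are example values, not recipe defaults:

```python
from lhotse import load_manifest_lazy

def remove_short_and_long_utt(cut) -> bool:
    # Keep only cuts whose duration falls inside the chosen range.
    return 1.0 <= cut.duration <= 15.0

cuts = load_manifest_lazy("data/fbank/cuts_L.jsonl.gz")
cuts = cuts.filter(remove_short_and_long_utt)
```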
-import argparse -import logging - -from lhotse import CutSet, load_manifest_lazy - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--fixed-transcript-path", - type=str, - default="data/fbank/text.fix", - help=""" - See https://github.com/wenet-e2e/WenetSpeech/discussions/54 - wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix - """, - ) - - parser.add_argument( - "--manifest-dir", - type=str, - default="data/fbank/", - help="Directory to store the manifest files", - ) - - parser.add_argument( - "--training-subset", - type=str, - default="L", - help="The training subset for wenetspeech.", - ) - - return parser - - -def load_fixed_text(fixed_text_path): - """ - fixed text format - X0000016287_92761015_S00001 我是徐涛 - X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯 - load into a dict - """ - fixed_text_dict = {} - with open(fixed_text_path, "r") as f: - for line in f: - cut_id, text = line.strip().split(" ", 1) - fixed_text_dict[cut_id] = text - return fixed_text_dict - - -def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path): - with CutSet.open_writer(fixed_manifest_path) as manifest_writer: - fixed_item = 0 - for i, cut in enumerate(manifest): - if i % 10000 == 0: - logging.info(f"Processing cut {i}, fixed {fixed_item}") - cut_id_orgin = cut.id - if cut_id_orgin.endswith("_sp0.9"): - cut_id = cut_id_orgin[:-6] - elif cut_id_orgin.endswith("_sp1.1"): - cut_id = cut_id_orgin[:-6] - else: - cut_id = cut_id_orgin - if cut_id in fixed_text_dict: - assert ( - len(cut.supervisions) == 1 - ), f"cut {cut_id} has {len(cut.supervisions)} supervisions" - if cut.supervisions[0].text != fixed_text_dict[cut_id]: - logging.info( - f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}" - ) - cut.supervisions[0].text = fixed_text_dict[cut_id] - fixed_item += 1 - manifest_writer.write(cut) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - fixed_text_path = args.manifest_dir + "text.fix" - fixed_text_dict = load_fixed_text(fixed_text_path) - logging.info(f"Loaded {len(fixed_text_dict)} fixed texts") - - dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz" - fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz" - logging.info(f"Loading dev manifest from {dev_manifest_path}") - cuts_dev_manifest = load_manifest_lazy(dev_manifest_path) - fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path) - logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}") - - manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz" - fixed_manifest_path = ( - args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz" - ) - logging.info(f"Loading manifest from {manifest_path}") - cuts_manifest = load_manifest_lazy(manifest_path) - fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path) - logging.info(f"Fixed training manifest saved to {fixed_manifest_path}") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py deleted file mode 100755 index d8622842f..000000000 --- a/egs/wenetspeech/ASR/local/prepare_char.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env 
python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input `lang_dir`, which should contain:: - - lang_dir/text, - - lang_dir/words.txt -and generates the following files in the directory `lang_dir`: - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. 
- Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - chars = list(word.strip(" \t")) - if contain_oov(token_sym_table, chars): - continue - lexicon.append((word, chars)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(text_file: str) -> Dict[str, int]: - """Generate tokens from the given text file. - Args: - text_file: - A file that contains text lines to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. - """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - whitespace = re.compile(r"([ \t\r\n]+)") - with open(text_file, "r", encoding="utf-8") as f: - for line in f: - line = re.sub(whitespace, "", line) - tokens_list = list(line) - for token in tokens_list: - if token not in tokens: - tokens[token] = len(tokens) - return tokens - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang directory.") - args = parser.parse_args() - - lang_dir = Path(args.lang_dir) - text_file = lang_dir / "text" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(text_file) - - lexicon = generate_lexicon(token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py b/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py deleted file mode 120000 index 2374cafdd..000000000 --- a/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/prepare_char_lm_training_data.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py deleted file mode 100644 index 52da3d6dc..000000000 --- a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py +++ /dev/null @@ 
-1,143 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path - -import lhotse -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - fix_manifests, - validate_recordings_and_supervisions, -) - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--kaldi-dir", - type=str, - help="""The directory containing kaldi style manifest, namely wav.scp, text and segments. - """, - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bank bins. - """, - ) - - parser.add_argument( - "--output-dir", - type=str, - default="data/fbank", - help="""The directory where the lhotse manifests and features to write to. - """, - ) - - parser.add_argument( - "--dataset", - type=str, - help="""The name of dataset. - """, - ) - - parser.add_argument( - "--partition", - type=str, - help="""Could be something like train, valid, test and so on. - """, - ) - - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=True, - help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", - ) - - parser.add_argument( - "--num-jobs", type=int, default=50, help="The num of jobs to extract feature." - ) - - return parser.parse_args() - - -def prepare_cuts(args): - logging.info(f"Prepare cuts from {args.kaldi_dir}.") - recordings, supervisions, _ = lhotse.load_kaldi_data_dir(args.kaldi_dir, 16000) - recordings, supervisions = fix_manifests(recordings, supervisions) - validate_recordings_and_supervisions(recordings, supervisions) - cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions) - return cuts - - -def compute_feature(args, cuts): - extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins)) - with get_executor() as ex: # Initialize the executor only once. 
- cuts_filename = f"{args.dataset}_cuts_{args.partition}.jsonl.gz" - if (args.output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - logging.info(f"Processing {cuts_filename}") - - if "train" in args.partition: - if args.perturb_speed: - logging.info(f"Doing speed perturb") - cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) - cuts = cuts.compute_and_store_features( - extractor=extractor, - storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}", - # when an executor is specified, make more partitions - num_jobs=args.num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cuts.to_file(args.output_dir / cuts_filename) - - -def main(args): - args.kaldi_dir = Path(args.kaldi_dir) - args.output_dir = Path(args.output_dir) - cuts = prepare_cuts(args) - compute_feature(args, cuts) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - main(args) diff --git a/egs/wenetspeech/ASR/local/prepare_lang.py b/egs/wenetspeech/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/wenetspeech/ASR/local/prepare_lang.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py deleted file mode 100755 index 112b50b79..000000000 --- a/egs/wenetspeech/ASR/local/prepare_pinyin.py +++ /dev/null @@ -1,276 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
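The --perturb-speed option above applies the usual 0.9x and 1.1x speed perturbation to training cuts before feature extraction, which is also why fix_manifest.py earlier strips the _sp0.9 and _sp1.1 suffixes when mapping cut ids back to the original transcripts. A standalone sketch of that augmentation step with lhotse (the manifest path is a placeholder):

```python
from lhotse import CutSet

# Triple the training data with 0.9x and 1.1x speed-perturbed copies,
# as done for the "train" partitions above.
cuts = CutSet.from_file("data/fbank/example_cuts_train_raw.jsonl.gz")
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
```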
- -""" -This script takes as input `lang_dir`, which should contain:: - - lang_dir/words.txt -and generates the following files in the directory `lang_dir`: - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" -import argparse -import re -from pathlib import Path -from typing import Dict, List - -import k2 -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.utils import text_to_pinyin - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare lang for pinyin", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument("--lang-dir", type=str, help="The lang directory.") - - parser.add_argument( - "--token-type", - default="full_with_tone", - type=str, - help="""The type of pinyin, should be in: - full_with_tone: zhōng guó - full_no_tone: zhong guo - partial_with_tone: zh ōng g uó - partial_no_tone: zh ong g uo - """, - ) - - parser.add_argument( - "--pinyin-errors", - default="split", - type=str, - help="""How to handle characters that has no pinyin, - see `text_to_pinyin` in icefall/utils.py for details - """, - ) - - return parser - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: - """Check if all the given tokens are in token symbol table. - Args: - token_sym_table: - Token symbol table that contains all the valid tokens. - tokens: - A list of tokens. 
- Returns: - Return True if there is any token not in the token_sym_table, - otherwise False. - """ - for tok in tokens: - if tok not in token_sym_table: - return True - return False - - -def generate_lexicon( - args, token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: - """Generate a lexicon from a word list and token_sym_table. - Args: - token_sym_table: - Token symbol table that mapping token to token ids. - words: - A list of strings representing words. - Returns: - Return a dict whose keys are words and values are the corresponding - tokens. - """ - lexicon = [] - for word in words: - tokens = text_to_pinyin( - word.strip(), mode=args.token_type, errors=args.pinyin_errors - ) - if contain_oov(token_sym_table, tokens): - print(f"Word : {word} contains OOV token, skipping.") - continue - lexicon.append((word, tokens)) - - # The OOV word is - lexicon.append(("", [""])) - return lexicon - - -def generate_tokens(args, words: List[str]) -> Dict[str, int]: - """Generate tokens from the given word list. - Args: - words: - A list that contains words to generate tokens. - Returns: - Return a dict whose keys are tokens and values are token ids ranged - from 0 to len(keys) - 1. - """ - tokens: Dict[str, int] = dict() - tokens[""] = 0 - tokens[""] = 1 - tokens[""] = 2 - for word in words: - word = word.strip() - tokens_list = text_to_pinyin( - word, mode=args.token_type, errors=args.pinyin_errors - ) - for token in tokens_list: - if token not in tokens: - tokens[token] = len(tokens) - return tokens - - -def main(): - parser = get_parser() - args = parser.parse_args() - - lang_dir = Path(args.lang_dir) - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - token_sym_table = generate_tokens(args, words) - - lexicon = generate_lexicon(args, token_sym_table, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/prepare_words.py b/egs/wenetspeech/ASR/local/prepare_words.py deleted file mode 100644 index d5f833db1..000000000 --- a/egs/wenetspeech/ASR/local/prepare_words.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input words.txt without ids: - - words_no_ids.txt -and generates the new words.txt with related ids. - - words.txt -""" - - -import argparse -import logging - -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Prepare words.txt", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-file", - default="data/lang_char/words_no_ids.txt", - type=str, - help="the words file without ids for WenetSpeech", - ) - parser.add_argument( - "--output-file", - default="data/lang_char/words.txt", - type=str, - help="the words file with ids for WenetSpeech", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - - input_file = args.input_file - output_file = args.output_file - - f = open(input_file, "r", encoding="utf-8") - lines = f.readlines() - new_lines = [] - add_words = [" 0", "!SIL 1", " 2", " 3"] - new_lines.extend(add_words) - - logging.info("Starting reading the input file") - for i in tqdm(range(len(lines))): - x = lines[i] - idx = 4 + i - new_line = str(x.strip("\n")) + " " + str(idx) - new_lines.append(new_line) - - logging.info("Starting writing the words.txt") - f_out = open(output_file, "w", encoding="utf-8") - - # LG decoding needs below symbols. - id1, id2, id3 = ( - str(len(new_lines)), - str(len(new_lines) + 1), - str(len(new_lines) + 2), - ) - add_words = ["#0 " + id1, " " + id2, " " + id3] - new_lines.extend(add_words) - - for line in new_lines: - f_out.write(line) - f_out.write("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py deleted file mode 100755 index 5de3c23a9..000000000 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Johns Hopkins University (Piotr Żelasko) -# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
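prepare_words.py above assigns consecutive integer ids to the word list and reserves a few special symbols at the head and tail of words.txt; the angle-bracketed symbols appear to have been stripped from this diff text, but in char-based icefall recipes they are conventionally <eps>, !SIL, <SPOKEN_NOISE> and <UNK> at the start, plus #0, <s> and </s> (needed for LG decoding) at the end. A hypothetical condensed sketch of that layout:

```python
def write_words_txt(words, path):
    # Reserved symbols first, then the corpus words, then the LG-decoding symbols.
    entries = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"] + list(words)
    entries += ["#0", "<s>", "</s>"]
    with open(path, "w", encoding="utf-8") as f:
        for idx, w in enumerate(entries):
            f.write(f"{w} {idx}\n")

# write_words_txt(["我", "你好"], "data/lang_char/words.txt")
```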
- -import argparse -import logging -import re -from pathlib import Path - -from lhotse import CutSet, SupervisionSegment -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall import setup_logger -from icefall.utils import str2bool - -# Similar text filtering and normalization procedure as in: -# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh - - -def normalize_text( - utt: str, - # punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), - punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), - whitespace_pattern=re.compile(r"\s\s+"), -) -> str: - return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) - - -def has_no_oov( - sup: SupervisionSegment, - oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"), -) -> bool: - return oov_pattern.search(sup.text) is None - - -def preprocess_wenet_speech(perturb_speed: bool = False): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - output_dir.mkdir(exist_ok=True) - - # Note: By default, we preprocess all sub-parts. - # You can delete those that you don't need. - # For instance, if you don't want to use the L subpart, just remove - # the line below containing "L" - dataset_parts = ( - "DEV", - "TEST_NET", - "TEST_MEETING", - "S", - "M", - "L", - ) - - logging.info("Loading manifest (may take 10 minutes)") - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix="jsonl.gz", - prefix="wenetspeech", - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - # Note this step makes the recipe different than LibriSpeech: - # We must filter out some utterances and remove punctuation - # to be consistent with Kaldi. - logging.info("Filtering OOV utterances from supervisions") - m["supervisions"] = m["supervisions"].filter(has_no_oov) - logging.info(f"Normalizing text in {partition}") - for sup in m["supervisions"]: - text = str(sup.text) - orig_text = text - sup.text = normalize_text(sup.text) - text = str(sup.text) - if len(orig_text) != len(text): - logging.info( - f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" - ) - - # Create long-recording cut manifests. - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - # Run data augmentation that needs to be done in the - # time domain. - if partition not in ["DEV", "TEST_NET", "TEST_MEETING"] and perturb_speed: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", - ) - return parser.parse_args() - - -def main(): - setup_logger(log_filename="./log-preprocess-wenetspeech") - - args = get_args() - preprocess_wenet_speech(perturb_speed=args.perturb_speed) - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/sort_lm_training_data.py b/egs/wenetspeech/ASR/local/sort_lm_training_data.py deleted file mode 120000 index efef2c445..000000000 --- a/egs/wenetspeech/ASR/local/sort_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py deleted file mode 100644 index bdf5a3984..000000000 --- a/egs/wenetspeech/ASR/local/text2segments.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) -# 2022 Xiaomi Corp. (authors: Weiji Zhuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input "text", which refers to the transcript file for -WenetSpeech: - - text -and generates the output file text_word_segmentation which is implemented -with word segmenting: - - text_words_segmentation -""" - - -import argparse -from multiprocessing import Pool - -import jieba -import paddle -from tqdm import tqdm - -# In PaddlePaddle 2.x, dynamic graph mode is turned on by default, -# and 'data()' is only supported in static graph mode. So if you -# want to use this api, should call 'paddle.enable_static()' before -# this api to enter static graph mode. 
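For reference, a short usage illustration of the text cleaning performed by preprocess_wenetspeech.py above: supervisions containing <SIL|MUSIC|NOISE|OTHER> tags are dropped, and punctuation tags such as <PERIOD> are stripped from the rest (the sample transcripts are made up):

```python
import re

punct_pattern = re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>")
oov_pattern = re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>")
whitespace_pattern = re.compile(r"\s\s+")

def normalize_text(utt: str) -> str:
    return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))

texts = ["今天天气不错<PERIOD>", "<MUSIC>", "大家好<COMMA>我是主持人"]
kept = [normalize_text(t) for t in texts if oov_pattern.search(t) is None]
# kept == ["今天天气不错", "大家好<COMMA>我是主持人"]; <COMMA> survives because the
# active pattern above intentionally excludes it.
```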
-# paddle.enable_static() -# paddle.disable_signal_handler() -jieba.enable_paddle() - - -def get_parser(): - parser = argparse.ArgumentParser( - description="Chinese Word Segmentation for text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--num-process", - "-n", - default=20, - type=int, - help="the number of processes", - ) - parser.add_argument( - "--input-file", - "-i", - default="data/lang_char/text", - type=str, - help="the input text file for WenetSpeech", - ) - parser.add_argument( - "--output-file", - "-o", - default="data/lang_char/text_words_segmentation", - type=str, - help="the text implemented with words segmenting for WenetSpeech", - ) - - return parser - - -def cut(lines): - if lines is not None: - cut_lines = jieba.cut(lines, use_paddle=True) - return [i for i in cut_lines] - else: - return None - - -def main(): - parser = get_parser() - args = parser.parse_args() - - num_process = args.num_process - input_file = args.input_file - output_file = args.output_file - # parallel mode does not support use_paddle - # jieba.enable_parallel(num_process) - - with open(input_file, "r", encoding="utf-8") as fr: - lines = fr.readlines() - - with Pool(processes=num_process) as p: - new_lines = list(tqdm(p.imap(cut, lines), total=len(lines))) - - with open(output_file, "w", encoding="utf-8") as fw: - for line in new_lines: - fw.write(" ".join(line) + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py deleted file mode 100755 index d1d237a68..000000000 --- a/egs/wenetspeech/ASR/local/text2token.py +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) -# 2022 Xiaomi Corp. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
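text2segments.py above runs jieba in paddle mode over a multiprocessing pool. A simplified sketch of the same word-segmentation step using jieba's default (non-paddle) mode, assuming jieba and tqdm are installed; the file names are placeholders:

```python
from multiprocessing import Pool

import jieba
from tqdm import tqdm

def cut(line):
    # jieba.lcut returns the segmented words of one transcript line.
    return jieba.lcut(line.strip())

def segment_file(input_file, output_file, num_process=20):
    with open(input_file, "r", encoding="utf-8") as fr:
        lines = fr.readlines()
    with Pool(processes=num_process) as p:
        segmented = list(tqdm(p.imap(cut, lines), total=len(lines)))
    with open(output_file, "w", encoding="utf-8") as fw:
        for words in segmented:
            fw.write(" ".join(words) + "\n")

# segment_file("data/lang_char/text", "data/lang_char/text_words_segmentation")
```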
- - -import argparse -import codecs -import re -import sys -from typing import List - -from pypinyin import lazy_pinyin, pinyin - -is_python2 = sys.version_info[0] == 2 - - -def exist_or_not(i, match_pos): - start_pos = None - end_pos = None - for pos in match_pos: - if pos[0] <= i < pos[1]: - start_pos = pos[0] - end_pos = pos[1] - break - - return start_pos, end_pos - - -def get_parser(): - parser = argparse.ArgumentParser( - description="convert raw text to tokenized text", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--nchar", - "-n", - default=1, - type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", - ) - parser.add_argument( - "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" - ) - parser.add_argument("--space", default="", type=str, help="space symbol") - parser.add_argument( - "--non-lang-syms", - "-l", - default=None, - type=str, - help="list of non-linguistic symobles, e.g., etc.", - ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") - parser.add_argument( - "--trans_type", - "-t", - type=str, - default="char", - choices=["char", "pinyin", "lazy_pinyin"], - help="""Transcript type. char/pinyin/lazy_pinyin""", - ) - return parser - - -def token2id( - texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" -) -> List[List[int]]: - """Convert token to id. - Args: - texts: - The input texts, it refers to the chinese text here. - token_table: - The token table is built based on "data/lang_xxx/token.txt" - token_type: - The type of token, such as "pinyin" and "lazy_pinyin". - oov: - Out of vocabulary token. When a word(token) in the transcript - does not exist in the token list, it is replaced with `oov`. - - Returns: - The list of ids for the input texts. 
- """ - if texts is None: - raise ValueError("texts can't be None!") - else: - oov_id = token_table[oov] - ids: List[List[int]] = [] - for text in texts: - chars_list = list(str(text)) - if token_type == "lazy_pinyin": - text = lazy_pinyin(chars_list) - sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text - ] - ids.append(sub_ids) - else: # token_type = "pinyin" - text = pinyin(chars_list) - sub_ids = [ - token_table[txt[0]] if txt[0] in token_table else oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - -def main(): - parser = get_parser() - args = parser.parse_args() - - rs = [] - if args.non_lang_syms is not None: - with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: - nls = [x.rstrip() for x in f.readlines()] - rs = [re.compile(re.escape(x)) for x in nls] - - if args.text: - f = codecs.open(args.text, encoding="utf-8") - else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) - - sys.stdout = codecs.getwriter("utf-8")( - sys.stdout if is_python2 else sys.stdout.buffer - ) - line = f.readline() - n = args.nchar - while line: - x = line.split() - print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols :]) # noqa E203 - - # get all matched positions - match_pos = [] - for r in rs: - i = 0 - while i >= 0: - m = r.search(a, i) - if m: - match_pos.append([m.start(), m.end()]) - i = m.end() - else: - break - if len(match_pos) > 0: - chars = [] - i = 0 - while i < len(a): - start_pos, end_pos = exist_or_not(i, match_pos) - if start_pos is not None: - chars.append(a[start_pos:end_pos]) - i = end_pos - else: - chars.append(a[i]) - i += 1 - a = chars - - if args.trans_type == "pinyin": - a = pinyin(list(str(a))) - a = [one[0] for one in a] - - if args.trans_type == "lazy_pinyin": - a = lazy_pinyin(list(str(a))) - - a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 - - a_flat = [] - for z in a: - a_flat.append("".join(z)) - - a_chars = [z.replace(" ", args.space) for z in a_flat] - - print("".join(a_chars)) - line = f.readline() - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh deleted file mode 100755 index 74f213707..000000000 --- a/egs/wenetspeech/ASR/prepare.sh +++ /dev/null @@ -1,427 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=0 -stop_stage=100 - -# Split L subset to this number of pieces -# This is to avoid OOM during feature extraction. -num_splits=1000 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/WenetSpeech -# You can find audio, WenetSpeech.json inside it. -# You can apply for the download credentials by following -# https://github.com/wenet-e2e/WenetSpeech#download -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download -lang_char_dir=data/lang_char - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
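text2token.py above supports char, pinyin and lazy_pinyin transcript types via pypinyin. A quick illustration of the difference between the three token streams, assuming pypinyin is installed:

```python
from pypinyin import lazy_pinyin, pinyin

chars = list("中国")                   # char tokens: ['中', '国']
print(lazy_pinyin(chars))              # ['zhong', 'guo']   (no tone marks)
print([p[0] for p in pinyin(chars)])   # ['zhōng', 'guó']   (with tone marks)
```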
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - [ ! -e $dl_dir/WenetSpeech ] && mkdir -p $dl_dir/WenetSpeech - - # If you have pre-downloaded it to /path/to/WenetSpeech, - # you can create a symlink - # - # ln -sfv /path/to/WenetSpeech $dl_dir/WenetSpeech - # - if [ ! -d $dl_dir/WenetSpeech/wenet_speech ] && [ ! -f $dl_dir/WenetSpeech/metadata/v1.list ]; then - log "Stage 0: You should download WenetSpeech first" - exit 1; - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - #ln -sfv /path/to/musan $dl_dir/musan - - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare WenetSpeech manifest" - # We assume that you have downloaded the WenetSpeech corpus - # to $dl_dir/WenetSpeech - mkdir -p data/manifests - lhotse prepare wenet-speech $dl_dir/WenetSpeech data/manifests -j $nj -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Preprocess WenetSpeech manifest" - if [ ! -f data/fbank/.preprocess_complete ]; then - python3 ./local/preprocess_wenetspeech.py --perturb-speed True - touch data/fbank/.preprocess_complete - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for DEV and TEST subsets of WenetSpeech (may take 2 minutes)" - python3 ./local/compute_fbank_wenetspeech_dev_test.py -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split S subset into ${num_splits} pieces" - split_dir=data/fbank/S_split_${num_splits} - if [ ! -f $split_dir/.split_completed ]; then - lhotse split $num_splits ./data/fbank/cuts_S_raw.jsonl.gz $split_dir - touch $split_dir/.split_completed - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Split M subset into ${num_splits} piece" - split_dir=data/fbank/M_split_${num_splits} - if [ ! -f $split_dir/.split_completed ]; then - lhotse split $num_splits ./data/fbank/cuts_M_raw.jsonl.gz $split_dir - touch $split_dir/.split_completed - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Split L subset into ${num_splits} pieces" - split_dir=data/fbank/L_split_${num_splits} - if [ ! 
-f $split_dir/.split_completed ]; then - lhotse split $num_splits ./data/fbank/cuts_L_raw.jsonl.gz $split_dir - touch $split_dir/.split_completed - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compute features for S" - python3 ./local/compute_fbank_wenetspeech_splits.py \ - --training-subset S \ - --num-workers 20 \ - --batch-duration 600 \ - --start 0 \ - --num-splits $num_splits -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compute features for M" - python3 ./local/compute_fbank_wenetspeech_splits.py \ - --training-subset M \ - --num-workers 20 \ - --batch-duration 600 \ - --start 0 \ - --num-splits $num_splits -fi - -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compute features for L" - python3 ./local/compute_fbank_wenetspeech_splits.py \ - --training-subset L \ - --num-workers 20 \ - --batch-duration 600 \ - --start 0 \ - --num-splits $num_splits -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Combine features for S" - if [ ! -f data/fbank/cuts_S.jsonl.gz ]; then - pieces=$(find data/fbank/S_split_${num_splits} -name "cuts_S.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_S.jsonl.gz - fi -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Combine features for M" - if [ ! -f data/fbank/cuts_M.jsonl.gz ]; then - pieces=$(find data/fbank/M_split_${num_splits} -name "cuts_M.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_M.jsonl.gz - fi -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Combine features for L" - if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then - pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_L.jsonl.gz - fi -fi - -whisper_mel_bins=80 -if [ $stage -le 129 ] && [ $stop_stage -ge 129 ]; then - log "Stage 129: compute whisper fbank for dev and test sets" - python3 ./local/compute_fbank_wenetspeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true -fi -if [ $stage -le 130 ] && [ $stop_stage -ge 130 ]; then - log "Stage 130: Comute features for whisper training set" - - split_dir=data/fbank/L_split_${num_splits} - if [ ! -f $split_dir/.split_completed ]; then - lhotse split $num_splits ./data/fbank/cuts_L_raw.jsonl.gz $split_dir - touch $split_dir/.split_completed - fi - - python3 ./local/compute_fbank_wenetspeech_splits.py \ - --training-subset L \ - --num-workers 8 \ - --batch-duration 1600 \ - --start 0 \ - --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \ - --num-splits $num_splits - - if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then - pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_L.jsonl.gz - fi -fi - -if [ $stage -le 131 ] && [ $stop_stage -ge 131 ]; then - log "Stage 131: concat feats into train set" - if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then - pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_L.jsonl.gz - fi -fi - - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Compute fbank for musan" - mkdir -p data/fbank - ./local/compute_fbank_musan.py -fi - -if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then - log "Stage 15: Prepare char based lang" - mkdir -p $lang_char_dir - - if ! which jq; then - echo "This script is intended to be used with jq but you have not installed jq - Note: in Linux, you can install jq with the following command: - 1. 
wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 - 2. chmod +x ./jq - 3. cp jq /usr/bin" && exit 1 - fi - if [ ! -f $lang_char_dir/text ] || [ ! -s $lang_char_dir/text ]; then - log "Prepare text." - gunzip -c data/manifests/wenetspeech_supervisions_L.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $lang_char_dir/text - fi - - # The implementation of chinese word segmentation for text, - # and it will take about 15 minutes. - if [ ! -f $lang_char_dir/text_words_segmentation ]; then - python3 ./local/text2segments.py \ - --num-process $nj \ - --input-file $lang_char_dir/text \ - --output-file $lang_char_dir/text_words_segmentation - fi - - cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt - - if [ ! -f $lang_char_dir/words.txt ]; then - python3 ./local/prepare_words.py \ - --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt - fi -fi - -if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then - log "Stage 16: Prepare char based L_disambig.pt" - if [ ! -f data/lang_char/L_disambig.pt ]; then - python3 ./local/prepare_char.py \ - --lang-dir data/lang_char - fi -fi - -# If you don't want to use LG for decoding, the following steps are not necessary. -if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then - log "Stage 17: Prepare G" - # It will take about 20 minutes. - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then - python3 ./shared/make_kn_lm.py \ - -ngram-order 3 \ - -text $lang_char_dir/text_words_segmentation \ - -lm $lang_char_dir/3-gram.unpruned.arpa - fi - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building LG - python3 -m kaldilm \ - --read-symbol-table="$lang_char_dir/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt - fi -fi - -if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then - log "Stage 18: Compile LG" - python ./local/compile_lg.py --lang-dir $lang_char_dir -fi - -# prepare RNNLM data -if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then - log "Stage 19: Prepare LM training data" - - log "Processing char based data" - text_out_dir=data/lm_char - - mkdir -p $text_out_dir - - log "Genearating training text data" - - if [ ! -f $text_out_dir/lm_data.pt ]; then - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $lang_char_dir/text_words_segmentation \ - --lm-archive $text_out_dir/lm_data.pt - fi - - log "Generating DEV text data" - # prepare validation text data - if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then - valid_text=${text_out_dir}/ - - gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $text_out_dir/valid_text - - python3 ./local/text2segments.py \ - --num-process $nj \ - --input-file $text_out_dir/valid_text \ - --output-file $text_out_dir/valid_text_words_segmentation - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $text_out_dir/valid_text_words_segmentation \ - --lm-archive $text_out_dir/lm_data_valid.pt - - # prepare TEST text data - if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then - log "Prepare text for test set." 
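Stage 15 above shells out to jq to pull the text field out of the gzipped supervision manifest. Where jq is unavailable, the same extraction can be sketched in a few lines of Python (the output name is a placeholder; the recipe's ./local/text2token.py -t char step would still follow):

```python
import gzip
import json

src = "data/manifests/wenetspeech_supervisions_L.jsonl.gz"
dst = "data/lang_char/text_raw"  # placeholder; the recipe tokenizes this afterwards

# Each line of the manifest is one JSON supervision record with a "text" field.
with gzip.open(src, "rt", encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
    for line in fin:
        fout.write(json.loads(line)["text"] + "\n")
```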
- for test_set in TEST_MEETING TEST_NET; do - gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \ - | jq '.text' | sed 's/"//g' \ - | ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text - - python3 ./local/text2segments.py \ - --num-process $nj \ - --input-file $text_out_dir/${test_set}_text \ - --output-file $text_out_dir/${test_set}_text_words_segmentation - done - - cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation - fi - - ./local/prepare_char_lm_training_data.py \ - --lang-char data/lang_char \ - --lm-data $text_out_dir/test_text_words_segmentation \ - --lm-archive $text_out_dir/lm_data_test.pt - -fi - -# sort RNNLM data -if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then - text_out_dir=data/lm_char - - log "Sort lm data" - - ./local/sort_lm_training_data.py \ - --in-lm-data $text_out_dir/lm_data.pt \ - --out-lm-data $text_out_dir/sorted_lm_data.pt \ - --out-statistics $text_out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $text_out_dir/lm_data_valid.pt \ - --out-lm-data $text_out_dir/sorted_lm_data-valid.pt \ - --out-statistics $text_out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $text_out_dir/lm_data_test.pt \ - --out-lm-data $text_out_dir/sorted_lm_data-test.pt \ - --out-statistics $text_out_dir/statistics-test.txt -fi - -export CUDA_VISIBLE_DEVICES="0,1" - -if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then - log "Stage 21: Train RNN LM model" - python ../../../icefall/rnn_lm/train.py \ - --start-epoch 0 \ - --world-size 2 \ - --num-epochs 20 \ - --use-fp16 0 \ - --embedding-dim 2048 \ - --hidden-dim 2048 \ - --num-layers 2 \ - --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data data/lm_char/sorted_lm_data.pt \ - --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ - --vocab-size 5537 \ - --master-port 12340 -fi - -if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then - log "Stage 22: Prepare pinyin based lang" - for token in full_with_tone partial_with_tone; do - lang_dir=data/lang_${token} - if [ ! 
-f $lang_dir/tokens.txt ]; then - cp data/lang_char/words.txt $lang_dir/words.txt - python local/prepare_pinyin.py \ - --token-type $token \ - --lang-dir $lang_dir - fi - python ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then - log "Stage 23: Modify transcript according to fixed results" - # See https://github.com/wenet-e2e/WenetSpeech/discussions/54 - wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix - python local/fix_manifest.py \ - --fixed-transcript-path data/fbank/text.fix \ - --training-subset L -fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py deleted file mode 120000 index f7321272b..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index 8b35187b1..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class WenetSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. 
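Editor's note: the WenetSpeechAsrDataModule deleted above is normally driven from a training or decoding script. The following is a minimal, hedged usage sketch, assuming the (now removed) asr_datamodule.py is still importable and the manifests under data/fbank exist; the CLI values and paths are placeholders, not part of the original diff.

    # Hedged sketch: drive the deleted data module from a script.
    import argparse
    from asr_datamodule import WenetSpeechAsrDataModule  # assumes the deleted file is on the path

    parser = argparse.ArgumentParser()
    WenetSpeechAsrDataModule.add_arguments(parser)
    # placeholder CLI values; train_dataloaders() also expects data/fbank/musan_cuts.jsonl.gz
    args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

    wenetspeech = WenetSpeechAsrDataModule(args)
    train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts())
    valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
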
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--training-subset", - type=str, - default="L", - help="The training subset for using", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - valid_dl = DataLoader( - validate, - batch_size=None, - sampler=valid_sampler, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz") - - @lru_cache() - def test_net_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_NET cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") - - @lru_cache() - def test_meeting_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_MEETING cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 120000 index a65957180..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git 
a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index 2bafe25d6..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,689 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -When training with the L subset, usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (1best) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -(4) fast beam search (nbest) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(5) fast beam search (nbest oracle WER) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (with LG) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import 
average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--batch", - type=int, - default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to - specify `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.35, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - 
max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - beam=params.beam_size, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
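Editor's note: decode_one_batch above converts hypothesis token IDs back to characters via lexicon.token_table. A toy, hedged illustration of that mapping with a made-up k2.SymbolTable (the real table is loaded from data/lang_char):

    # Hedged toy example of the id-to-symbol mapping used in decode_one_batch.
    import k2

    token_table = k2.SymbolTable.from_str("<blk> 0\n你 1\n好 2")  # hypothetical 3-symbol table
    hyp_tokens = [[1, 2], [2, 1]]  # toy output of one of the search functions above
    hyps = [[token_table[idx] for idx in hyp] for hyp in hyp_tokens]
    print(hyps)  # [['你', '好'], ['好', '你']]
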
- - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list(str(text)) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if params.decoding_method == "fast_beam_search_nbest_LG": - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if ( - params.decoding_method == "fast_beam_search_nbest" - or params.decoding_method == "fast_beam_search_nbest_oracle" - ): - params.suffix += f"-nbest-scale-{params.nbest_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += 
f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - elif params.batch is not None: - filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt" - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints([filenames], device=device)) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lg_filename = params.lang_dir + "/LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - - dev_cuts = wenetspeech.valid_cuts() - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dl = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py deleted file mode 100755 index 2e644ec2f..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py +++ /dev/null @@ -1,547 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
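Editor's note: save_results in the deleted decode.py above (and the aishell variant below) writes recognized transcripts and per-word error statistics via icefall.utils. A hedged, self-contained sketch of that pattern with toy results in the (cut_id, ref, hyp) format produced by decode_dataset:

    # Hedged sketch of the result-saving pattern; file names are placeholders.
    from icefall.utils import store_transcripts, write_error_stats

    results = [
        ("cut-001", ["今", "天", "天", "气"], ["今", "天", "天", "气"]),
        ("cut-002", ["很", "好"], ["很", "早"]),
    ]
    store_transcripts(filename="recogs-demo.txt", texts=results)
    with open("errs-demo.txt", "w") as f:
        wer = write_error_stats(f, "demo", results, enable_log=True)
    print(f"error rate: {wer}")
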
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from aishell import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from finetune import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - 
logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - dev_dl = aishell.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py deleted file mode 100755 index 8aea79fe3..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ /dev/null @@ -1,513 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a transducer model from PyTorch to ONNX. 
- -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=icefall_asr_wenetspeech_pruned_transducer_stateless2 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_10_avg_2.pt" - -cd exp -ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless2/export-onnx.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import num_tokens, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Conformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Conformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. 
- """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Conformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - encoder_out, encoder_out_lens = self.encoder(x, x_lens) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "conformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "stateless5", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. 
- - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - 
decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index 2f6ef488e..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. 
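# A minimal sketch of what "model averaging" amounts to, assuming each
# epoch-*.pt checkpoint stores its weights under the key "model" (as the
# checkpoints saved by this recipe do). icefall.checkpoint.average_checkpoints
# performs this averaging (with additional bookkeeping) for the script below;
# this standalone helper is only illustrative.
import torch

def average_state_dicts(filenames):
    """Element-wise average of the "model" state_dicts stored in `filenames`."""
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg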
-""" -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Please refer to -https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html -for how to use `cpu_jit.pt` for speech recognition. - -It will also generate 3 other files: `encoder_jit_script.pt`, -`decoder_jit_script.pt`, and `joiner_jit_script.pt`. Check ./jit_pretrained.py -for how to use them. - -(2) Export to torchscript model using torch.jit.trace() - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit-trace 1 - -It will generate the following files: - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - -Check ./jit_pretrained.py for usage. - -(3) Export `model.state_dict()` - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/wenetspeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 100 \ - --lang-dir data/lang_char - -You can find pretrained models at -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. 
- It will generate 4 files: - - encoder_jit_script.pt - - decoder_jit_script.pt - - joiner_jit_script.pt - - cpu_jit.pt (which combines the above 3 files) - - Check ./jit_pretrained.py for how to use xxx_jit_script.pt - """, - ) - - parser.add_argument( - "--jit-trace", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace. - It will generate 3 files: - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - - Check ./jit_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def export_encoder_model_jit_script( - encoder_model: nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.script() - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - script_model = torch.jit.script(encoder_model) - script_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_script( - decoder_model: nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.script() - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - script_model = torch.jit.script(decoder_model) - script_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_script( - joiner_model: nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - """ - script_model = torch.jit.script(joiner_model) - script_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -def export_encoder_model_jit_trace( - encoder_model: nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - traced_model = torch.jit.trace(encoder_model, (x, x_lens)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. 
- - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to("cpu") - model.eval() - - if params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.script") - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
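# A minimal, self-contained sketch of the pattern applied just below: marking
# forward() with torch.jit.ignore lets torch.jit.script() compile the rest of
# the module even though forward() takes an argument (a k2 ragged tensor) that
# TorchScript cannot handle. The Toy class here is purely hypothetical.
import torch
import torch.nn as nn

class Toy(nn.Module):
    @torch.jit.ignore
    def forward(self, x):  # left as plain Python; never compiled
        return self.scale(x)

    @torch.jit.export
    def scale(self, x: torch.Tensor) -> torch.Tensor:  # compiled and kept in the saved module
        return 2.0 * x

scripted = torch.jit.script(Toy())
print(scripted.scale(torch.ones(2)))  # tensor([2., 2.])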
- model.__class__.forward = torch.jit.ignore(model.__class__.forward) - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - - # Also export encoder/decoder/joiner separately - encoder_filename = params.exp_dir / "encoder_jit_script.pt" - export_encoder_model_jit_script(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_script.pt" - export_decoder_model_jit_script(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_script.pt" - export_joiner_model_jit_script(model.joiner, joiner_filename) - elif params.jit_trace is True: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py deleted file mode 100755 index c34f1593d..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ /dev/null @@ -1,1054 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
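# A minimal sketch (hypothetical paths) of how the three kinds of artifacts
# written by the export script above are typically loaded back:
import torch

# --jit 1        -> a single TorchScript module
jit_model = torch.jit.load("pruned_transducer_stateless2/exp/cpu_jit.pt")

# --jit-trace 1  -> encoder/decoder/joiner traced and saved separately
encoder = torch.jit.load("pruned_transducer_stateless2/exp/encoder_jit_trace.pt")

# default        -> a plain state_dict stored under the key "model"
state = torch.load("pruned_transducer_stateless2/exp/pretrained.pt", map_location="cpu")
# model.load_state_dict(state["model"])  # with `model` built by get_transducer_model()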
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless2/finetune.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --do-finetune 1 \ - --max-duration 100 - -""" - - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aishell import AishellAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--do-finetune", type=str2bool, default=False) - - parser.add_argument( - "--init-modules", - type=str, - default=None, - help=""" - Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma separated. If None, - all modules will be initialised. For example, if you only want to - initialise all parameters staring with "encoder", use "encoder"; - if you want to initialise parameters starting with encoder or decoder, - use "encoder,joiner". - """, - ) - - parser.add_argument( - "--finetune-ckpt", - type=str, - default=None, - help="Fine-tuning from which checkpoint (a path to a .pt file)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.0001, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=100000, - help="""Number of steps that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=100, - help="""Number of epochs that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--valid-interval", - type=int, - default=3000, - help="""When training_subset is L, set the valid_interval to 3000. - When training_subset is M, set the valid_interval to 1000. - When training_subset is S, set the valid_interval to 400. 
- """, - ) - - parser.add_argument( - "--model-warm-step", - type=int, - default=3000, - help="""When training_subset is L, set the model_warm_step to 3000. - When training_subset is M, set the model_warm_step to 500. - When training_subset is S, set the model_warm_step to 100. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - 
return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def load_model_params( - ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True -): - """Load model params from checkpoint - - Args: - ckpt (str): Path to the checkpoint - model (nn.Module): model to be loaded - - """ - logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") - - # if module list is empty, load the whole model from ckpt - if not init_modules: - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - else: - src_state_dict = checkpoint["model"] - dst_state_dict = model.state_dict() - for module in init_modules: - logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [ - k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") - ] - dst_keys = [ - k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") - ] - assert set(src_keys) == set(dst_keys) # two sets should match exactly - for key in src_keys: - dst_state_dict[key] = src_state_dict.pop(key) - - model.load_state_dict(dst_state_dict, strict=strict) - - return None - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. 
- model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if isinstance(y, list): - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. 
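# A small sketch restating the warm-up schedule used above for the pruned-loss
# weight (warmup itself is batch_idx_train / model_warm_step); the function
# name is chosen here only for illustration:
def pruned_loss_weight(warmup: float) -> float:
    if warmup < 1.0:
        return 0.0   # warm-up phase: train only on the simple loss
    if 1.0 < warmup < 2.0:
        return 0.1   # keep the pruned loss small for one more warm-up period
    return 1.0       # fully warmed up

assert pruned_loss_weight(0.5) == 0.0
assert pruned_loss_weight(1.5) == 0.1
assert pruned_loss_weight(3.0) == 1.0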
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # load model parameters for model fine-tuning - if params.do_finetune: - modules = params.init_modules.split(",") if params.init_modules else None - checkpoints = load_model_params( - ckpt=params.finetune_ckpt, model=model, init_modules=modules - ) - else: - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - rank=rank, - ) 
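# A small sketch (with hypothetical parameter names) of how the --init-modules
# value is interpreted by load_model_params() above: each comma-separated entry
# is treated as a key prefix, and only matching parameters are copied from the
# fine-tuning checkpoint.
init_modules = [m.strip() for m in "encoder,joiner".split(",")]
state_dict_keys = [
    "encoder.encoder.layers.0.self_attn.in_proj.weight",
    "decoder.embedding.weight",
    "joiner.output_linear.weight",
]
for module in init_modules:
    picked = [k for k in state_dict_keys if k.startswith(module + ".")]
    print(module, "->", picked)
# encoder -> ['encoder.encoder.layers.0.self_attn.in_proj.weight']
# joiner  -> ['joiner.output_linear.weight']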
- - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments( - parser - ) # you may replace this with your own dataset - add_finetune_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py deleted file mode 100755 index aee1a2175..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, either exported by `torch.jit.trace()` -or by `torch.jit.script()`, and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit-trace 1 - -or - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless2/jit_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_trace.pt \ - --tokens data/lang_char/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav - -or - -./pruned_transducer_stateless2/jit_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_script.pt \ - --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_script.pt \ - --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_script.pt \ - --tokens data/lang_char/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can find pretrained models at -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. 
- expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - 
opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - symbol_table = k2.SymbolTable.from_file(args.tokens) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = "".join([symbol_table[i] for i in hyp]) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/model.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py deleted file mode 100755 index 2d46eede1..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ /dev/null @@ -1,303 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
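# The essence of the consistency check performed by the script below, as a
# minimal sketch with random stand-in tensors (no ONNX runtime involved); the
# tolerance matches the atol=1e-05 used in the real comparison.
import torch

torch_out = torch.randn(1, 12, 512)                        # stand-in for the torchscript encoder output
onnx_out = torch_out + 1e-7 * torch.randn_like(torch_out)  # stand-in for the onnxruntime output
assert torch.allclose(onnx_out, torch_out, atol=1e-5), (onnx_out - torch_out).abs().max()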
- -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. - -Usage: - -./pruned_transducer_stateless2/onnx_check.py \ - --jit-filename ./t/cpu_jit.pt \ - --onnx-encoder-filename ./t/encoder.onnx \ - --onnx-decoder-filename ./t/decoder.onnx \ - --onnx-joiner-filename ./t/joiner.onnx \ - --onnx-joiner-encoder-proj-filename ./t/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename ./t/joiner_decoder_proj.onnx - -You can generate cpu_jit.pt, encoder.onnx, decoder.onnx, and other -xxx.onnx files using ./export.py - -We provide pretrained models at: -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging - -from icefall import is_module_available - -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort -import torch - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model exported by torch.jit.script", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] - - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - 
- torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] - - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] - - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 512] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 512) - projected_decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - 
test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) - logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py deleted file mode 120000 index f1bfbee49..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless5/onnx_pretrained.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/optim.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py deleted file mode 100755 index 642de72d7..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav -(2) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -(3) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav -You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. -Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by -./pruned_transducer_stateless2/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""Used only when --decoding-method is beam_search - and modified_beam_search """, - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --decoding-method is greedy_search. 
- """, - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - hyps = [] - msg = f"Using {params.decoding_method}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size 
= encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py deleted file mode 100644 index 49977e01b..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,1066 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -For training with the L subset: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --training-subset L - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 10 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --use-fp16 True \ - --training-subset L - -For training with the M subset: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 29 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 1000 \ - --model-warm-step 500 \ - --save-every-n 1000 \ - --training-subset M - -For training with the S subset: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 29 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 400 \ - --model-warm-step 100 \ - --save-every-n 1000 \ - --training-subset S -""" - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. 
- If it is positive, it will load checkpoint from - pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--valid-interval", - type=int, - default=3000, - help="""When training_subset is L, set the valid_interval to 3000. - When training_subset is M, set the valid_interval to 1000. - When training_subset is S, set the valid_interval to 400. - """, - ) - - parser.add_argument( - "--model-warm-step", - type=int, - default=3000, - help="""When training_subset is L, set the model_warm_step to 3000. - When training_subset is M, set the model_warm_step to 500. - When training_subset is S, set the model_warm_step to 100. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if isinstance(y, list): - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
- pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - wenetspeech = WenetSpeechAsrDataModule(args) - - train_cuts = wenetspeech.train_cuts() - valid_cuts = wenetspeech.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 10 seconds - # - # Caution: There is a reason to select 10.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 10.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = c.supervisions[0].text.replace(" ", "") - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. 
" - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = wenetspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - if not params.print_diagnostics and params.start_batch == 0: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = batch["supervisions"]["text"] - num_tokens = sum(len(i) for i in texts) - - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params) - raise - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/__init__.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index 02d01b343..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 100644 index 23a877b2f..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1,1531 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copy -import math -import warnings -from typing import List, Optional, Tuple - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledConv2d, - ScaledLinear, -) -from torch import Tensor, nn - -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -class Conformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension, also the output dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - layer_dropout (float): layer-dropout rate. - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. - dynamic_chunk_training (bool): whether to use dynamic chunk training, if - you want to train a streaming model, this is expected to be True. - When setting True, it will use a masking strategy to make the attention - see only limited left and right context. - short_chunk_threshold (float): a threshold to determinize the chunk size - to be used in masking training, if the randomly generated chunk size - is greater than ``max_len * short_chunk_threshold`` (max_len is the - max sequence length of current batch) then it will use - full context in training (i.e. with chunk size equals to max_len). - This will be used only when dynamic_chunk_training is True. - short_chunk_size (int): see docs above, if the randomly generated chunk - size equals to or less than ``max_len * short_chunk_threshold``, the - chunk size will be sampled uniformly from 1 to short_chunk_size. - This also will be used only when dynamic_chunk_training is True. - num_left_chunks (int): the left context (in chunks) attention can see, the - chunk size is decided by short_chunk_threshold and short_chunk_size. - A minus value means seeing full left context. - This also will be used only when dynamic_chunk_training is True. - causal (bool): Whether to use causal convolution in conformer encoder - layer. This MUST be True when using dynamic_chunk_training. - """ - - def __init__( - self, - num_features: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - dropout: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - dynamic_chunk_training: bool = False, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 25, - num_left_chunks: int = -1, - causal: bool = False, - ) -> None: - super(Conformer, self).__init__() - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). 
- # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_layers = num_encoder_layers - self.d_model = d_model - self.cnn_module_kernel = cnn_module_kernel - self.causal = causal - self.dynamic_chunk_training = dynamic_chunk_training - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - self.num_left_chunks = num_left_chunks - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - layer_dropout, - cnn_module_kernel, - causal, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self._init_state: List[torch.Tensor] = [torch.empty(0)] - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, d_model) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # Caution: We assume the subsampling factor is 4! - - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - - assert x.size(0) == lengths.max().item() - - src_key_padding_mask = make_pad_mask(lengths) - - if self.dynamic_chunk_training: - assert ( - self.causal - ), "Causal convolution is required for streaming conformer." - max_len = x.size(0) - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - chunk_size = max_len - else: - chunk_size = chunk_size % self.short_chunk_size + 1 - - mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - x = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - else: - x = self.encoder( - x, - pos_emb, - mask=None, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths - - @torch.jit.export - def get_init_state( - self, left_context: int, device: torch.device - ) -> List[torch.Tensor]: - """Return the initial cache state of the model. - Args: - left_context: The left context size (in frames after subsampling). - Returns: - Return the initial state of the model, it is a list containing two - tensors, the first one is the cache for attentions which has a shape - of (num_encoder_layers, left_context, encoder_dim), the second one - is the cache of conv_modules which has a shape of - (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). - NOTE: the returned tensors are on the given device. 
- """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: - # Note: It is OK to share the init state as it is - # not going to be modified by the model - return self._init_state - - init_states: List[torch.Tensor] = [ - torch.zeros( - ( - self.encoder_layers, - left_context, - self.d_model, - ), - device=device, - ), - torch.zeros( - ( - self.encoder_layers, - self.cnn_module_kernel - 1, - self.d_model, - ), - device=device, - ), - ] - - self._init_state = init_states - - return init_states - - @torch.jit.export - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: Optional[List[Tensor]] = None, - processed_lens: Optional[Tensor] = None, - left_context: int = 64, - right_context: int = 4, - chunk_size: int = 16, - simulate_streaming: bool = False, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - processed_lens: - How many frames (after subsampling) have been processed for each sequence. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - chunk_size: - The chunk size for decoding, this will be used to simulate streaming - decoding using masking. - simulate_streaming: - If setting True, it will use a masking strategy to simulate streaming - fashion (i.e. every chunk data only see limited left context and - right context). The whole sequence is supposed to be send at a time - When using simulate_streaming. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - Returns: - Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. - - decode_states, the updated states including the information - of current chunk. - """ - - # x: [N, T, C] - # Caution: We assume the subsampling factor is 4! 
- - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - - if not simulate_streaming: - assert states is not None - assert processed_lens is not None - assert ( - len(states) == 2 - and states[0].shape - == (self.encoder_layers, left_context, x.size(0), self.d_model) - and states[1].shape - == ( - self.encoder_layers, - self.cnn_module_kernel - 1, - x.size(0), - self.d_model, - ) - ), f"""The length of states MUST be equal to 2, and the shape of - first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, - given {states[0].shape}. the shape of second element should be - {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, - given {states[1].shape}.""" - - lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output - - src_key_padding_mask = make_pad_mask(lengths) - - processed_mask = torch.arange(left_context, device=x.device).expand( - x.size(0), left_context - ) - processed_lens = processed_lens.view(x.size(0), 1) - processed_mask = (processed_lens <= processed_mask).flip(1) - - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) - - embed = self.encoder_embed(x) - - # cut off 1 frame on each size of embed as they see the padding - # value which causes a training and decoding mismatch. - embed = embed[:, 1:-1, :] - - embed, pos_enc = self.encoder_pos(embed, left_context) - embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - - x, states = self.encoder.chunk_forward( - embed, - pos_enc, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - states=states, - left_context=left_context, - right_context=right_context, - ) # (T, B, F) - if right_context > 0: - x = x[0:-right_context, ...] - lengths -= right_context - else: - assert states is None - states = [] # just to make torch.script.jit happy - # this branch simulates streaming decoding using mask as we are - # using in training time. - src_key_padding_mask = make_pad_mask(lengths) - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - assert x.size(0) == lengths.max().item() - - num_left_chunks = -1 - if left_context >= 0: - assert left_context % chunk_size == 0 - num_left_chunks = left_context // chunk_size - - mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=num_left_chunks, - device=x.device, - ) - x = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths, states - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - causal (bool): Whether to use causal convolution in conformer encoder - layer. This MUST be True when using dynamic_chunk_training and streaming decoding. 
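In the simulated-streaming branch above, `subsequent_chunk_mask` (imported earlier in this file) builds a per-frame visibility mask that is then negated. The helper below is not the recipe's implementation, only a plain-Python sketch assuming the usual WeNet-style semantics, to show what `chunk_size=16` with `left_context=64` (hence `num_left_chunks=4`) lets each frame attend to:

```python
import torch

def chunk_attn_mask(size: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
    # True = not allowed to attend, matching `mask = ~subsequent_chunk_mask(...)` above.
    allowed = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        chunk_idx = i // chunk_size
        start = 0 if num_left_chunks < 0 else max(0, (chunk_idx - num_left_chunks) * chunk_size)
        end = min(size, (chunk_idx + 1) * chunk_size)
        allowed[i, start:end] = True
    return ~allowed

mask = chunk_attn_mask(size=80, chunk_size=16, num_left_chunks=4)
print(mask[0, :16])   # frame 0 may see only its own chunk
print(mask[79, :])    # the last frame sees 4 left chunks plus its own chunk
```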
- Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - causal: bool = False, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - - self.layer_dropout = layer_dropout - - self.d_model = d_model - - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) - - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else 0.1 - ) - else: - alpha = 1.0 - - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - - # multi-headed self-attention module - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - - src = src + self.dropout(src_att) - - # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - src = src + self.dropout(conv) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src - - @torch.jit.export - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - states: List[Tensor], - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[Tensor, List[Tensor]]: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder layer (required). 
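The `warmup` / `layer_dropout` logic in `forward()` above blends each layer with its input early in training and occasionally bypasses it almost entirely. A small sketch of how `alpha` is chosen (the same rule, extracted here only for illustration):

```python
import torch

def layer_scale(warmup: float, layer_dropout: float = 0.075, training: bool = True) -> float:
    # alpha = 1.0 means "use this layer fully"; 0.1 means "mostly bypass it".
    warmup_scale = min(0.1 + warmup, 1.0)
    if not training:
        return 1.0
    return warmup_scale if torch.rand(()).item() <= (1.0 - layer_dropout) else 0.1

# The layer output is then blended as: alpha * out + (1 - alpha) * src_orig
alpha = layer_scale(warmup=0.5)  # usually 0.6 here, occasionally 0.1
```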
- pos_emb: Positional embedding tensor (required). - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - Shape: - src: (S, N, E). - pos_emb: (N, 2*(S+left_context)-1, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - assert not self.training - assert len(states) == 2 - assert states[0].shape == (left_context, src.size(1), src.size(2)) - - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - - # We put the attention cache this level (i.e. before linear transformation) - # to save memory consumption, when decoding in streaming fashion, the - # batch size would be thousands (for 32GB machine), if we cache key & val - # separately, it needs extra several GB memory. - # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val - # separately) if needed. - key = torch.cat([states[0], src], dim=0) - val = key - if right_context > 0: - states[0] = key[ - -(left_context + right_context) : -right_context, ... # noqa - ] - else: - states[0] = key[-left_context:, ...] - - # multi-headed self-attention module - src_att = self.self_attn( - src, - key, - val, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - left_context=left_context, - )[0] - - src = src + self.dropout(src_att) - - # convolution module - conv, conv_cache = self.conv_module(src, states[1], right_context) - states[1] = conv_cache - - src = src + self.dropout(conv) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - return src, states - - -class ConformerEncoder(nn.Module): - r"""ConformerEncoder is a stack of N encoder layers - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). 
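`chunk_forward` above caches the pre-projection inputs rather than separate keys and values; after each chunk, the newest `left_context` frames (excluding any look-ahead frames) become the next cache. A shape-level sketch with assumed sizes, following the `(frames, batch, dim)` layout from the docstrings:

```python
import torch

left_context, right_context, batch, dim = 64, 4, 2, 256
chunk = 16 + right_context                           # decoded frames plus look-ahead

attn_cache = torch.zeros(left_context, batch, dim)   # states[0] for one layer
src = torch.randn(chunk, batch, dim)                 # current chunk after encoder_embed

key = torch.cat([attn_cache, src], dim=0)            # attention sees cache + chunk
if right_context > 0:
    # keep the newest left_context frames, excluding the look-ahead frames
    new_cache = key[-(left_context + right_context):-right_context]
else:
    new_cache = key[-left_context:]
assert new_cache.shape == attn_cache.shape
```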
- Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - """ - output = src - - for layer_index, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - - return output - - @torch.jit.export - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - states: List[Tensor], - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - Shape: - src: (S, N, E). - pos_emb: (N, 2*(S+left_context)-1, E). - mask: (S, S). - src_key_padding_mask: (N, S). 
-        S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-        """
-        assert not self.training
-        assert len(states) == 2
-        assert states[0].shape == (
-            self.num_layers,
-            left_context,
-            src.size(1),
-            src.size(2),
-        )
-        assert states[1].size(0) == self.num_layers
-
-        output = src
-
-        for layer_index, mod in enumerate(self.layers):
-            cache = [states[0][layer_index], states[1][layer_index]]
-            output, cache = mod.chunk_forward(
-                output,
-                pos_emb,
-                states=cache,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                left_context=left_context,
-                right_context=right_context,
-            )
-            states[0][layer_index] = cache[0]
-            states[1][layer_index] = cache[1]
-
-        return output, states
-
-
-class RelPositionalEncoding(torch.nn.Module):
-    """Relative positional encoding module.
-    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
-    Args:
-        d_model: Embedding dimension.
-        dropout_rate: Dropout rate.
-        max_len: Maximum input length.
-    """
-
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
-        """Construct a PositionalEncoding object."""
-        super(RelPositionalEncoding, self).__init__()
-        self.d_model = d_model
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
-
-    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
-        """Reset the positional encodings."""
-        x_size_1 = x.size(1) + left_context
-        if self.pe is not None:
-            # self.pe contains both positive and negative parts
-            # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x_size_1 * 2 - 1:
-                # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
-        # Suppose `i` means the position of the query vector and `j` means the
-        # position of the key vector. We use positive relative positions when keys
-        # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x_size_1, self.d_model)
-        pe_negative = torch.zeros(x_size_1, self.d_model)
-        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.d_model, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.d_model)
-        )
-        pe_positive[:, 0::2] = torch.sin(position * div_term)
-        pe_positive[:, 1::2] = torch.cos(position * div_term)
-        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-
-        # Reverse the order of the positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-        pe_negative = pe_negative[1:].unsqueeze(0)
-        pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
-
-    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
-        """Add positional encoding.
-        Args:
-            x (torch.Tensor): Input tensor (batch, time, `*`).
-            left_context (int): left context (in frames) used during streaming decoding.
-                This is used only in real streaming decoding; in other circumstances,
-                it MUST be 0.
-        Returns:
-            torch.Tensor: Encoded tensor (batch, time, `*`).
-            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
-        """
-        self.extend_pe(x, left_context)
-        x_size_1 = x.size(1) + left_context
-        pos_emb = self.pe[
-            :,
-            self.pe.size(1) // 2
-            - x_size_1
-            + 1 : self.pe.size(1) // 2  # noqa E203
-            + x.size(1),
-        ]
-        return self.dropout(x), self.dropout(pos_emb)
-
-
-class RelPositionMultiheadAttention(nn.Module):
-    r"""Multi-Head Attention layer with relative position encoding
-    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Args:
-        embed_dim: total dimension of the model.
-        num_heads: parallel attention heads.
-        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
- Examples:: - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear( - embed_dim, embed_dim, bias=True, initial_scale=0.25 - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() - - def _pos_bias_u(self): - return self.pos_bias_u * self.pos_bias_u_scale.exp() - - def _pos_bias_v(self): - return self.pos_bias_v * self.pos_bias_v_scale.exp() - - def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.01) - nn.init.normal_(self.pos_bias_v, std=0.01) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. 
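The `pos_bias_u` / `pos_bias_v` parameters above are scaled through an `exp()` of a separate zero-initialized scale, so training starts from the raw bias and the optimizer gets a log-domain knob. A tiny check of that identity (illustrative shapes only):

```python
import torch
import torch.nn as nn

pos_bias_u = nn.Parameter(torch.randn(8, 64) * 0.01)   # (num_heads, head_dim), assumed sizes
pos_bias_u_scale = nn.Parameter(torch.zeros(()))        # starts at 0, so exp(0) == 1
effective = pos_bias_u * pos_bias_u_scale.exp()
assert torch.allclose(effective, pos_bias_u)
```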
If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), - self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - left_context=left_context, - ) - - def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: - """Compute relative positional encoding. - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - - time2 = time1 + left_context - if not torch.jit.is_tracing(): - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" - - if torch.jit.is_tracing(): - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(time2) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - - x = x.reshape(-1, n) - x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, time1, time2) - return x - else: - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. 
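`rel_shift` above converts scores indexed by relative offset (last dimension of length `left_context + 2*time1 - 1`, which is exactly the `pos_emb` slice length) into scores indexed by key position, using either a gather (tracing) or a strided view. A slow but readable reference of my own, for illustration, that should agree with the strided version:

```python
import torch

def rel_shift_reference(x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    # out[b, h, i, j] = x[b, h, i, (time1 - 1) - i + j], i.e. the score for
    # query i and key j is read from relative offset (j - i).
    b, h, time1, n = x.shape
    time2 = time1 + left_context
    assert n == left_context + 2 * time1 - 1
    out = torch.empty(b, h, time1, time2, dtype=x.dtype)
    for i in range(time1):
        for j in range(time2):
            out[:, :, i, j] = x[:, :, i, time1 - 1 - i + j]
    return out

x = torch.randn(1, 2, 5, 2 * 5 - 1)
print(rel_shift_reference(x).shape)  # torch.Size([1, 2, 5, 5])
```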
- pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) - p = p.permute(0, 2, 3, 1) - - q_with_bias_u = (q + self._pos_bias_u()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self._pos_bias_v()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd, left_context) - - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - - # If we are using dynamic_chunk_training and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`, at this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking - # positions to avoid invalid loss value below. 
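The comment above describes why fully masked rows need special handling: a softmax over a row that is entirely `-inf` yields NaNs, which the code then neutralizes by zeroing those weights with the combined mask. A two-line illustration of the failure mode and the fix (using `isinf` here purely for demonstration):

```python
import torch

row = torch.full((4,), float("-inf"))   # an attention row that is fully masked
weights = torch.softmax(row, dim=-1)
print(weights)                          # tensor([nan, nan, nan, nan])

# Zeroing the masked positions after softmax keeps padded-only queries harmless.
weights = weights.masked_fill(torch.isinf(row), 0.0)
print(weights)                          # tensor([0., 0., 0., 0.])
```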
- if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - causal (bool): Whether to use causal convolution. - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - causal: bool = False, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - self.causal = causal - - self.pointwise_conv1 = ScaledConv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. 
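For reference, the `nn.functional.glu` applied after `pointwise_conv1` splits the doubled channels into a linear half and a sigmoid-gated half, which is why the comment above worries about the pre-gate statistics. A quick check of that identity with assumed sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 2 * 8, 10)             # (batch, 2 * channels, time)
y = F.glu(x, dim=1)                       # (batch, channels, time)
a, b = x.chunk(2, dim=1)                  # linear half and gate half
assert torch.allclose(y, a * torch.sigmoid(b))
```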
- self.deriv_balancer1 = ActivationBalancer( - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 - ) - - self.lorder = kernel_size - 1 - padding = (kernel_size - 1) // 2 - if self.causal: - padding = 0 - - self.depthwise_conv = ScaledConv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channel_dim=1, min_positive=0.05, max_positive=1.0 - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.25, - ) - - def forward( - self, - x: Tensor, - cache: Optional[Tensor] = None, - right_context: int = 0, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - Args: - x: Input tensor (#time, batch, channels). - cache: The cache of depthwise_conv, only used in real streaming - decoding. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - src_key_padding_mask: the mask for the src keys per batch (optional). - Returns: - If cache is None return the output tensor (#time, batch, channels). - If cache is not None, return a tuple of Tensor, the first one is - the output tensor (#time, batch, channels), the second one is the - new cache for next chunk (#kernel_size - 1, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - if self.causal and self.lorder > 0: - if cache is None: - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - else: - assert not self.training, "Cache should be None in training time" - assert cache.size(0) == self.lorder - x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - if right_context > 0: - cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa - ..., - ] - else: - cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - # torch.jit.script requires return types be the same as annotated above - if cache is None: - cache = torch.empty(0) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. 
The output shape is (N, ((T-1)//2 - 1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - """ - assert in_channels >= 7 - super().__init__() - - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=1, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear( - layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels - ) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55 - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - Args: - x: - Its shape is (N, T, idim). - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x - - -if __name__ == "__main__": - feature_dim = 50 - c = Conformer(num_features=feature_dim, d_model=128, nhead=4) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index d665f3364..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,918 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -When training with the L subset, the offline usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 4 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 4 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 4 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -When training with the L subset, the streaming usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_streaming \ - --use-averaged-model True \ - --max-duration 600 \ - --epoch 7 \ - --avg 1 \ - --decoding-method greedy_search \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - --decode-chunk-size 16 \ - --left-context 64 - -(2) modified beam search -./pruned_transducer_stateless5/decode.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_streaming \ - --use-averaged-model True \ - --max-duration 600 \ - --epoch 7 \ - --avg 1 \ - --decoding-method modified_beam_search \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - --decode-chunk-size 16 \ - --left-context 64 - -(3) fast beam search -./pruned_transducer_stateless5/decode.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_streaming \ - --use-averaged-model True \ - --max-duration 600 \ - --epoch 7 \ - --avg 1 \ - --decoding-method fast_beam_search \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - --decode-chunk-size 16 \ - --left-context 64 - -(4) modified beam search with RNNLM shallow fusion -./pruned_transducer_stateless5/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search_lm_shallow_fusion \ - --beam-size 4 \ - --lm-type rnn \ - --lm-scale 0.3 \ - --lm-exp-dir /path/to/LM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 -""" - - -import argparse -import glob -import logging -import math -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall import ContextGraph, LmScorer, NgramLm -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = 
argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. 
- """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--context-score", - type=float, - default=2, - help=""" - The bonus score of each token for the context biasing words/phrases. - Used only when --decoding_method is modified_beam_search. - """, - ) - - parser.add_argument( - "--context-file", - type=str, - default="", - help=""" - The path of the context biasing lists, one word/phrase each line - Used only when --decoding_method is modified_beam_search. - """, - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=3, - help="""Token Ngram used for rescoring. - Used only when the decoding method is - modified_beam_search_ngram_rescoring, or LODR - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="""ID of the backoff symbol. - Used only when the decoding method is - modified_beam_search_ngram_rescoring""", - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - beam=params.beam_size, - encoder_out_lens=encoder_out_lens, - context_graph=context_graph, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - key = f"beam_size_{params.beam_size}" - if params.has_contexts: - key += f"-context-score-{params.context_score}" - 
else: - key += "-no-context-words" - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - LM: Optional[LmScorer] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - LM: - A neural network LM, used during shallow fusion - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - batch=batch, - context_graph=context_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ) - - if os.path.exists(params.context_file): - params.has_contexts = True - else: - params.has_contexts = False - - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - if params.has_contexts: - params.suffix += f"-context-score-{params.context_score}" - else: - params.suffix += "-no-contexts-words" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if "ngram" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if params.use_shallow_fusion: - if params.lm_type == "rnn": - params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" - elif params.lm_type == "transformer": - params.suffix += f"-transformer-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter 
{params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - # only load N-gram LM when needed - if "ngram" in params.decoding_method or "LODR" in params.decoding_method: - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - # import pdb; pdb.set_trace() - # only load the neural network LM if doing shallow fusion - if params.use_shallow_fusion: - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - - num_param = sum([p.numel() for p in LM.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - else: - LM = None - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - if params.decoding_method == "modified_beam_search": - if os.path.exists(params.context_file): - contexts_text = [] - for line in open(params.context_file).readlines(): - contexts_text.append(line.strip()) - contexts = graph_compiler.texts_to_ids(contexts_text) - context_graph = ContextGraph(params.context_score) - context_graph.build([(c, 0.0) for c in contexts]) - else: - context_graph = None - else: - 
context_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - - dev_cuts = wenetspeech.valid_cuts() - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dls = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - context_graph=context_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - LM=LM, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py deleted file mode 100644 index e522943c0..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. 
- self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 - - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decoder.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 6775ee67e..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index 972e44ca4..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py deleted file mode 100755 index 30068d01a..000000000 --- 
a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ /dev/null @@ -1,680 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_7_avg_1.pt" - -cd exp -ln -s pretrained_epoch_7_avg_1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx-streaming.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained-streaming.py for how to -use the exported ONNX models. - -You can find the exported models in -https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-zh-2023-05-23 -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import num_tokens, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Conformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Conformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - self.num_encoder_layers = encoder.encoder_layers - self.encoder_dim = encoder.d_model - self.cnn_module_kernel = encoder.cnn_module_kernel - - # Note you can tune these values - self.left_context = 64 # after subsampling - self.chunk_size = 16 # after subsampling - self.right_context = 0 # after subsampling - - subsampling_factor = 4 - self.pad_length = (self.right_context + 2) * subsampling_factor + 3 - - self.T = (self.chunk_size * subsampling_factor) + self.pad_length - self.decode_chunk_len = self.chunk_size * subsampling_factor - - def forward( - self, - x: torch.Tensor, - cached_attn: torch.Tensor, - cached_conv: torch.Tensor, - processed_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Conformer.forward - - Args: - x: - A 3-D tensor of shape (N, self.T, C) - cached_attn: - A 3-D tensor of shape - (num_encoder_layers, self.left_context, N, self.encoder_dim) - cached_conv: - A 3-D tensor of shape - (num_encoder_layers, self.cnn_module_kernel-1, N, self.encoder_dim) - processed_lens: - A 1-D tensor of shape (N,). It contains number of processed frames - after subsampling. Its dtype is torch.int64. 
- Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, self.chunk_size, joiner_dim) - - new_cached_attn, it has the same shape as cached_attn - - new_cached_conv, it has the same shape as cached_conv - """ - assert x.size(1) == self.T, (x.shape, self.T) - N = x.size(0) - x_lens = torch.full((N,), fill_value=self.T, device=x.device, dtype=torch.int64) - - ( - encoder_out, - _, - [new_cached_attn, new_cached_conv], - ) = self.encoder.streaming_forward( - x, - x_lens, - states=[cached_attn, cached_conv], - processed_lens=processed_lens, - left_context=self.left_context, - right_context=self.right_context, - chunk_size=self.chunk_size, - ) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, new_cached_attn, new_cached_conv - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - N = 1 - x = torch.zeros(N, encoder_model.T, 80, dtype=torch.float32) - cached_attn = torch.zeros( - encoder_model.num_encoder_layers, - encoder_model.left_context, - N, - encoder_model.encoder_dim, - ) - cached_conv = torch.zeros( - encoder_model.num_encoder_layers, - encoder_model.cnn_module_kernel - 1, - N, - encoder_model.encoder_dim, - ) - processed_lens = torch.zeros((N,), dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, cached_attn, cached_conv, processed_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "cached_attn", "cached_conv", "processed_lens"], - output_names=["encoder_out", "new_cached_attn", "new_cached_conv"], - dynamic_axes={ - "x": {0: "N"}, - "cached_attn": {2: "N"}, - "cached_conv": {2: "N"}, - "processed_lens": {0: "N"}, - "encoder_out": {0: "N"}, - "new_cached_attn": {2: "N"}, - "new_cached_conv": {2: "N"}, - }, - ) - - meta_data = { - "model_type": "conformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "stateless5", - "pad_length": str(encoder_model.pad_length), - "decode_chunk_len": str(encoder_model.decode_chunk_len), - "encoder_dim": str(encoder_model.encoder_dim), - "num_encoder_layers": str(encoder_model.num_encoder_layers), - "cnn_module_kernel": str(encoder_model.cnn_module_kernel), - "left_context": str(encoder_model.left_context), - "right_context": str(encoder_model.right_context), - "chunk_size": str(encoder_model.chunk_size), - "T": str(encoder_model.T), - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if not params.causal_convolution: - logging.info("Seting causal_convolution to True for exporting streaming models") - params.causal_convolution = True - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum(p.numel() for p in model.parameters()) - logging.info(f"Number of model parameters: {num_param}") - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - 
filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py 
deleted file mode 100755 index 1c9eb8648..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ /dev/null @@ -1,603 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_9_avg_1.pt" - -cd exp -ln -s pretrained_epoch_9_avg_1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Conformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Conformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Conformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - encoder_out, encoder_out_lens = self.encoder(x, x_lens) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. 
- The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "conformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "stateless5", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 5ff1f4a3b..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple 
authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage for offline: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --tokens data/lang_char/tokens.txt \ - --epoch 4 \ - --avg 1 - -It will generate a file exp_dir/pretrained.pt for offline ASR. - -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --tokens data/lang_char/tokens.txt \ - --epoch 4 \ - --avg 1 \ - --jit True - -It will generate a file exp_dir/cpu_jit.pt for offline ASR. - -Usage for streaming: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --tokens data/lang_char/tokens.txt \ - --epoch 7 \ - --avg 1 - -It will generate a file exp_dir/pretrained.pt for streaming ASR. - -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --tokens data/lang_char/tokens.txt \ - --epoch 7 \ - --avg 1 \ - --jit True - -It will generate a file exp_dir/cpu_jit.pt for streaming ASR. - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/wenetspeech/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 4 \ - --avg 1 \ - --decoding-method greedy_search \ - --max-duration 100 \ - --lang-dir data/lang_char -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - convert_scaled_to_non_scaled(model, inplace=True) - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/joiner.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index f5279e151..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py deleted file mode 120000 index d13a1e063..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/model.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index 7b417fd89..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/model.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py deleted file mode 100755 index 8c192913e..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py +++ /dev/null @@ -1,235 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See 
../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_4_avg_1.pt" -git lfs pull --include "exp/cpu_jit_epoch_4_avg_1_torch.1.7.1.pt" - -cd exp -ln -s pretrained_epoch_9_avg_1_torch.1.7.1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -4. 
Run this file - -./pruned_transducer_stateless5/onnx_check.py \ - --jit-filename $repo/exp/cpu_jit_epoch_4_avg_1_torch.1.7.1.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - return parser - - -def test_encoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - C = 80 - for i in range(3): - N = torch.randint(low=1, high=20, size=(1,)).item() - T = torch.randint(low=30, high=50, size=(1,)).item() - logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - - x = torch.rand(N, T, C) - x_lens = torch.randint(low=30, high=T + 1, size=(N,)) - x_lens[0] = T - - torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) - torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - - onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] - decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) - projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - - torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - 
logging.info(vars(args)) - - torch_model = torch.jit.load(args.jit_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - test_encoder(torch_model, onnx_model) - - logging.info("Test decoder") - test_decoder(torch_model, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_model, onnx_model) - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py deleted file mode 100755 index cca26feb0..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ /dev/null @@ -1,461 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) - -""" -This script loads ONNX models exported by ./export-onnx.py -and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_7_avg_1.pt" - -cd exp -ln -s pretrained_epoch_7_avg_1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx-streaming.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file with the exported ONNX models - -./pruned_transducer_stateless5/onnx_pretrained-streaming.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav - -Note: Even though this script only supports decoding a single file, -the exported ONNX models do support batch processing. 
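# A minimal sketch of the consistency-check idea used by onnx_check.py above:
# run the same random input through two implementations of the same computation
# and compare the outputs with torch.allclose, reporting the largest absolute
# difference on failure. The Linear layer is a hypothetical stand-in for the
# real encoder/decoder/joiner pairs.
import torch

def assert_close(name: str, a: torch.Tensor, b: torch.Tensor, atol: float = 1e-4):
    max_diff = (a - b).abs().max().item()
    assert torch.allclose(a, b, atol=atol), f"{name}: max abs diff {max_diff}"

proj = torch.nn.Linear(4, 8)
x = torch.rand(3, 4)
y_ref = proj(x)                              # e.g. the torchscript output
y_new = x @ proj.weight.t() + proj.bias      # e.g. the ONNX output
assert_close("projection", y_ref, y_new)
print("outputs match")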
- -You can find the exported models in -https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-zh-2023-05-23 -""" - -import argparse -import logging -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - self.init_encoder_states() - - def init_encoder_states(self, batch_size: int = 1): - encoder_meta = self.encoder.get_modelmeta().custom_metadata_map - print(encoder_meta) - - model_type = encoder_meta["model_type"] - assert model_type == "conformer", model_type - - decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - T = int(encoder_meta["T"]) - pad_length = int(encoder_meta["pad_length"]) - - encoder_dim = int(encoder_meta["encoder_dim"]) - cnn_module_kernel = int(encoder_meta["cnn_module_kernel"]) - left_context = int(encoder_meta["left_context"]) - num_encoder_layers = int(encoder_meta["num_encoder_layers"]) - - self.cached_attn = torch.zeros( - num_encoder_layers, - left_context, - batch_size, - encoder_dim, - ).numpy() - self.cached_conv = torch.zeros( - num_encoder_layers, - cnn_module_kernel - 1, - batch_size, - encoder_dim, - ).numpy() - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - logging.info(f"pad_length: {pad_length}") - logging.info(f"encoder_dim: {encoder_dim}") - logging.info(f"cnn_module_kernel: {cnn_module_kernel}") - logging.info(f"left_context: {left_context}") - logging.info(f"num_encoder_layers: {num_encoder_layers}") - - self.segment = T - self.offset = decode_chunk_len - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - 
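# The init_*/run_* methods of this wrapper all follow the same onnxruntime
# pattern: create an InferenceSession with explicit thread settings, then drive
# sess.run() with input/output names queried from the session rather than
# hard-coded. A minimal, self-contained sketch of that pattern; the tiny Linear
# model and the "tiny.onnx" path are assumptions made only for illustration.
import torch
import onnxruntime as ort

model = torch.nn.Linear(4, 8).eval()
x = torch.rand(2, 4)
torch.onnx.export(model, (x,), "tiny.onnx", input_names=["x"], output_names=["y"])

opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
sess = ort.InferenceSession(
    "tiny.onnx", sess_options=opts, providers=["CPUExecutionProvider"]
)
(y,) = sess.run(
    [sess.get_outputs()[0].name],            # output names, queried not guessed
    {sess.get_inputs()[0].name: x.numpy()},  # input name -> numpy array
)
print(torch.allclose(model(x), torch.from_numpy(y), atol=1e-4))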
logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def _build_encoder_input_output( - self, x: torch.Tensor, processed_lens: int - ) -> Tuple[Dict[str, np.ndarray], List[str]]: - assert x.size(0) == 1 - encoder_input = { - "x": x.numpy(), - "cached_attn": self.cached_attn, - "cached_conv": self.cached_conv, - "processed_lens": torch.full( - (1,), fill_value=processed_lens, dtype=torch.int64 - ).numpy(), - } - encoder_output = ["encoder_out", "new_cached_attn", "new_cached_conv"] - - return encoder_input, encoder_output - - def _update_states(self, states: List[np.ndarray]): - self.cached_attn = states[0] - self.cached_conv = states[1] - - def run_encoder(self, x: torch.Tensor, num_processed_frames: int) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, self.T, C). It only implements N == 1 - num_processed_frames: - Number of processed frames before subsampling. - Returns: - Return a 3-D tensor of shape (N, chunk_size, joiner_dim) - """ - # assume subsampling_factor is 4 - num_processed_frames = num_processed_frames // 4 - encoder_input, encoder_output_names = self._build_encoder_input_output( - x, num_processed_frames - ) - out = self.encoder.run(encoder_output_names, encoder_input) - - self._update_states(out[1:]) - - return torch.from_numpy(out[0]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. 
- """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - return OnlineFbank(opts) - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. - Returns: - Return the decoded results so far. - """ - - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - encoder_out = encoder_out.squeeze(0) - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {args.sound_file}") - waves = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=sample_rate, - )[0] - - tail_padding = torch.zeros(int(1.0 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) - - num_processed_frames = 0 - segment = model.segment - offset = model.offset - - context_size = model.context_size - hyp = None - decoder_out = None - - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - encoder_out = model.run_encoder(frames, num_processed_frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - text = "" - for i in hyp[context_size:]: - text += symbol_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, 
level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py deleted file mode 100755 index 4b4ddd332..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py +++ /dev/null @@ -1,429 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_4_avg_1.pt" -git lfs pull --include "exp/cpu_jit_epoch_4_avg_1_torch.1.7.1.pt" - -cd exp -ln -s pretrained_epoch_9_avg_1_torch.1.7.1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./pruned_transducer_stateless5/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. 
", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). 
Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. 
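# The batched greedy loop below leans on torch.nn.utils.rnn.pack_padded_sequence:
# packing reorders frames time-major, so packed.batch_sizes says how many
# utterances are still alive at each time step and packed.unsorted_indices maps
# the (sorted) results back to the original batch order. A small sketch of that
# bookkeeping with toy lengths:
import torch
from torch.nn.utils.rnn import pack_padded_sequence

enc = torch.arange(3 * 4).float().reshape(3, 4, 1)   # (N=3, T=4, C=1)
lens = torch.tensor([2, 4, 3])
packed = pack_padded_sequence(enc, lens, batch_first=True, enforce_sorted=False)

print(packed.batch_sizes.tolist())       # [3, 3, 2, 1]: active utterances per step
print(packed.unsorted_indices.tolist())  # how to restore the original order

offset = 0
for batch_size in packed.batch_sizes.tolist():
    frame = packed.data[offset : offset + batch_size]  # (batch_size, C)
    offset += batch_size
    # ...run the decoder/joiner on `frame` for the surviving utterances here...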
- """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += symbol_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, 
level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/optim.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 210374f22..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100644 index 17428e19d..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Offline Usage: -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav -(2) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -(3) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp_L_offline/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav -You can also use `./pruned_transducer_stateless5/exp_L_offline/epoch-xx.pt`. -Note: ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. 
- """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=48000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search and modified_beam_search ", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - hyps = [] - msg = f"Using {params.decoding_method}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - 
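# Stripped to its core, the greedy_search branch above walks each utterance
# frame by frame and emits at most --max-sym-per-frame non-blank symbols per
# frame, refreshing the decoder output after every emission. A toy sketch of
# that loop; toy_decoder and toy_joiner are hypothetical stand-ins for the real
# stateless decoder and joiner, not the icefall implementations.
import torch

def toy_decoder(context):                 # context: last `context_size` token ids
    return torch.tensor(context, dtype=torch.float32).mean(dim=0, keepdim=True)

def toy_joiner(enc_frame, dec_out, vocab_size=6):
    return torch.rand(vocab_size) + 0.0 * (enc_frame.sum() + dec_out.sum())

blank_id, context_size, max_sym_per_frame = 0, 2, 1
hyp = [blank_id] * context_size
encoder_out = torch.rand(5, 8)            # (T, encoder_dim)

for t in range(encoder_out.size(0)):
    dec_out = toy_decoder(hyp[-context_size:])
    for _ in range(max_sym_per_frame):
        y = int(toy_joiner(encoder_out[t], dec_out).argmax())
        if y == blank_id:                 # blank: move on to the next frame
            break
        hyp.append(y)                     # non-blank: emit and refresh the decoder
        dec_out = toy_decoder(hyp[-context_size:])

print("decoded token ids:", hyp[context_size:])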
hyps.append([lexicon.token_table[idx] for idx in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index ff7bfeda9..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py deleted file mode 120000 index e58473a04..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py deleted file mode 100644 index 810d94135..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. 
- """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - # print(current_encoder_out.shape) - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. 
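# The index_select that follows replicates each stream's encoder frame once per
# surviving hypothesis: row_ids(1) of the ragged hypothesis shape maps every
# hypothesis back to the stream it came from. A tiny sketch of that gather with
# the ragged row ids written out by hand:
import torch

encoder_frames = torch.tensor([[1.0, 1.0],    # stream 0
                               [2.0, 2.0],    # stream 1
                               [3.0, 3.0]])   # stream 2
hyps_per_stream = [2, 1, 3]                   # active hypotheses per stream
row_ids = torch.tensor(
    [s for s, n in enumerate(hyps_per_stream) for _ in range(n)],
    dtype=torch.int64,                        # index_select expects int64 indices
)
per_hyp_frames = torch.index_select(encoder_frames, dim=0, index=row_ids)
print(row_ids.tolist())        # [0, 0, 1, 2, 2, 2]
print(per_hyp_frames)          # each stream's frame repeated once per hypothesis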
- current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. 
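# Per frame, modified_beam_search() above does three things: add each
# hypothesis' running log probability to the log-softmax of the joiner output,
# flatten the (hypothesis, token) pairs, and keep the best num_active_paths of
# them per stream. A minimal single-stream sketch of that pruning step:
import torch

num_active_paths, vocab_size = 4, 6
hyp_log_probs = torch.tensor([-0.1, -0.9, -1.5])          # 3 surviving hypotheses
token_log_probs = torch.rand(3, vocab_size).log_softmax(dim=-1)

scores = (hyp_log_probs.unsqueeze(1) + token_log_probs).reshape(-1)
topk_scores, topk_indexes = scores.topk(num_active_paths)
topk_hyp = (topk_indexes // vocab_size).tolist()          # which hypothesis to extend
topk_token = (topk_indexes % vocab_size).tolist()         # which token to append
print(list(zip(topk_hyp, topk_token, topk_scores.tolist())))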
- """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py deleted file mode 100644 index b396aa9b8..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ /dev/null @@ -1,674 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
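# Both onnx_pretrained-streaming.py above and streaming_decode.py below share
# the same chunking arithmetic: buffer incoming feature frames, and whenever at
# least `segment` frames are ready, hand a segment-sized window to the encoder
# and advance the read pointer by `offset` (smaller than `segment`, so
# consecutive windows overlap by the model's required right context). A sketch
# of that bookkeeping with toy numbers; 39 and 32 are illustrative, the real
# values come from the encoder metadata (T and decode_chunk_len).
segment, offset = 39, 32
num_ready, num_processed = 0, 0
for chunk in range(5):                 # pretend 5 batches of 100 frames arrive
    num_ready += 100
    while num_ready - num_processed >= segment:
        window = (num_processed, num_processed + segment)  # frames fed to encoder
        num_processed += offset
        print(f"chunk {chunk}: encode frames [{window[0]}, {window[1]})")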
- -""" -Usage: -(1) greedy search -python pruned_transducer_stateless5/streaming_decode.py \ - --epoch 7 \ - --avg 1 \ - --decode-chunk-size 16 \ - --left-context 64 \ - --right-context 0 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --decoding-method greedy_search \ - --num-decode-streams 2000 - -(2) modified beam search -python pruned_transducer_stateless5/streaming_decode.py \ - --epoch 7 \ - --avg 1 \ - --decode-chunk-size 16 \ - --left-context 64 \ - --right-context 0 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --decoding-method modified_beam_search \ - --num-decode-streams 2000 - -(3) fast beam search -python pruned_transducer_stateless5/streaming_decode.py \ - --epoch 7 \ - --avg 1 \ - --decode-chunk-size 16 \ - --left-context 64 \ - --right-context 0 \ - --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --decoding-method fast_beam_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num-active-paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--right-context", - type=int, - default=0, - help="right context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames( - params.decode_chunk_size * params.subsampling_factor - ) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # if T is less than 7 there will be an error in time reduction layer, - # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - # we plus 2 here because we will cut off one frame on each size of - # encoder_embed output as they see invalid paddings. so we need extra 2 - # frames. 
- tail_length = 7 + (2 + params.right_context) * params.subsampling_factor - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = [ - torch.stack([x[0] for x in states], dim=2), - torch.stack([x[1] for x in states], dim=2), - ] - - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - left_context=params.left_context, - right_context=params.right_context, - processed_lens=processed_lens, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = [states[0][i], states[1][i]] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
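# decode_dataset() below keeps at most --num-decode-streams DecodeStreams in
# flight and retires the finished ones after every chunk. Note that it deletes
# them by index in reverse order: removing the smallest index first would shift
# the remaining indexes and retire the wrong streams. A tiny sketch of that
# detail:
active = ["stream-a", "stream-b", "stream-c", "stream-d"]
finished = [1, 3]                      # indexes reported finished for this chunk
for i in sorted(finished, reverse=True):
    print("retiring", active[i])
    del active[i]
print("still decoding:", active)       # ['stream-a', 'stream-c']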
- decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - decode_stream.set_features(fbank(samples.to(device))) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].decoding_result() - decode_results.append( - ( - decode_streams[i].id, - list(decode_streams[i].ground_truth), - [lexicon.token_table[idx] for idx in hyp], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - hyp = decode_streams[i].decoding_result() - decode_results.append( - ( - decode_streams[i].id, - list(decode_streams[i].ground_truth), - [lexicon.token_table[idx] for idx in hyp], - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - # sort results so we can easily compare the difference between two - # recognition results - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
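write_error_stats (used just below) produces detailed per-word statistics and alignments; for orientation, the headline error rate itself reduces to an edit distance over token lists. A minimal, self-contained sketch of that computation (not icefall's implementation):

from typing import List

def error_rate(ref: List[str], hyp: List[str]) -> float:
    """Levenshtein distance between token lists, divided by the reference length."""
    m, n = len(ref), len(hyp)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,         # deletion
                dp[i][j - 1] + 1,         # insertion
                dp[i - 1][j - 1] + cost,  # substitution
            )
    return dp[m][n] / max(m, 1)

# Character-level lists, as used for Chinese recipes, give a CER:
print(error_rate(list("我们今天出发"), list("我们明天出发")))  # 1/6 ≈ 0.167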
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - params.suffix += f"-right-context-{params.right_context}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - params.causal_convolution = True - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints 
({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - wenetspeech = WenetSpeechAsrDataModule(args) - - dev_cuts = wenetspeech.valid_cuts() - test_net_cuts = wenetspeech.test_net_cuts() - test_meeting_cuts = wenetspeech.test_meeting_cuts() - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index 931e699d9..000000000 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1205 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
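Regarding the averaged-model path above: average_checkpoints_with_averaged_model reconstructs the mean over a range "(excluded) to" the endpoint, presumably from the running averages stored in the two endpoint checkpoints. The underlying arithmetic can be sketched with toy tensors (this is only an illustration, not icefall's actual implementation):

import torch

def window_average(avg_start: torch.Tensor, n_start: int,
                   avg_end: torch.Tensor, n_end: int) -> torch.Tensor:
    # If avg_k is the running mean of the first k checkpoints, the mean over
    # checkpoints (n_start, n_end] follows from simple re-weighting.
    return (avg_end * n_end - avg_start * n_start) / (n_end - n_start)

# Toy check with scalar "parameters" from four checkpoints:
ckpts = torch.tensor([1.0, 2.0, 3.0, 4.0])
avg2 = ckpts[:2].mean()  # running average after 2 checkpoints
avg4 = ckpts.mean()      # running average after 4 checkpoints
print(window_average(avg2, 2, avg4, 4))  # tensor(3.5000) == mean of checkpoints 3 and 4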
-""" -Usage for offline ASR: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_offline \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 2 \ - --max-duration 120 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --average-period 1000 \ - --training-subset L - -Usage for streaming ASR: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless5/exp_L_streaming \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 1 \ - --max-duration 140 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --average-period 1000 \ - --training-subset L \ - --dynamic-chunk-training True \ - --causal-convolution True \ - --short-chunk-size 25 \ - --num-left-chunks 4 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. 
- """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--valid-interval", - type=int, - default=3000, - help="""When training_subset is L, set the valid_interval to 3000. - When training_subset is M, set the valid_interval to 1000. - When training_subset is S, set the valid_interval to 400. - """, - ) - - parser.add_argument( - "--model-warm-step", - type=int, - default=3000, - help="""When training_subset is L, set the model_warm_step to 3000. - When training_subset is M, set the model_warm_step to 500. - When training_subset is S, set the model_warm_step to 100. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - feature_dim: The model input dim. 
It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - dynamic_chunk_training=params.dynamic_chunk_training, - short_chunk_size=params.short_chunk_size, - num_left_chunks=params.num_left_chunks, - causal=params.causal_convolution, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
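The model_avg handled here is the running average described in the --average-period help string earlier. A toy check of that quoted update formula, showing it keeps the average equal to the plain mean of all snapshots taken so far (scalar tensors stand in for model parameters):

import torch

def update_averaged(param: torch.Tensor, avg: torch.Tensor,
                    batch_idx_train: int, average_period: int) -> torch.Tensor:
    # Formula quoted in the --average-period help string above.
    w = average_period / batch_idx_train
    return param * w + avg * (1.0 - w)

period = 200
snapshots = [torch.tensor(v) for v in (1.0, 2.0, 3.0, 4.0)]
avg = snapshots[0].clone()
for k, snap in enumerate(snapshots[1:], start=2):
    avg = update_averaged(snap, avg, batch_idx_train=k * period, average_period=period)
print(avg)  # tensor(2.5000) == mean of the four snapshots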
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if isinstance(y, list): - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. 
- world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - wenetspeech = WenetSpeechAsrDataModule(args) - - train_cuts = wenetspeech.train_cuts() - valid_cuts = wenetspeech.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 10 seconds - # - # Caution: There is a reason to select 10.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 10.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = c.supervisions[0].text.replace(" ", "") - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. 
" - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = wenetspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - if not params.print_diagnostics and params.start_batch == 0: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - texts = batch["supervisions"]["text"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(texts) - if type(y) == list: - y = k2.RaggedTensor(y) - - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/shared b/egs/wenetspeech/ASR/shared deleted file mode 120000 index e9461a6d7..000000000 --- a/egs/wenetspeech/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../librispeech/ASR/shared \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/asr_datamodule.py b/egs/wenetspeech/ASR/whisper/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/wenetspeech/ASR/whisper/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py deleted file mode 100755 index 34b1c80ef..000000000 --- a/egs/wenetspeech/ASR/whisper/decode.py +++ /dev/null @@ -1,526 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
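The sanity check above catches "CUDA out of memory" errors, points at --max-duration, and re-raises. The same pattern, reduced to a generic standalone helper (a sketch only, not part of the deleted recipe):

import torch

def run_with_oom_hint(step_fn, *args, **kwargs):
    """Run one training step; on CUDA OOM, print the usual hint and re-raise."""
    try:
        return step_fn(*args, **kwargs)
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("Out of memory: consider lowering --max-duration and retrying.")
        raise

# Trivial usage (no GPU work involved):
print(run_with_oom_hint(torch.ones, 3))  # tensor([1., 1., 1.])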
-""" -Usage: -# Command for decoding using fine-tuned models: -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 - -# Command for decoding using pretrained models (before fine-tuning): - -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --epoch -1 --avg 1 \ - --remove-whisper-encoder-input-length-restriction False \ - --beam-size 10 --max-duration 50 - -""" - -import argparse -import logging -import re -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -import whisper -from asr_datamodule import WenetSpeechAsrDataModule -from lhotse.cut import Cut -from tn.chinese.normalizer import Normalizer -from whisper.normalizers import BasicTextNormalizer -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from zhconv import convert - -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def average_checkpoints( - filenames: List[Path], device: torch.device = torch.device("cpu") -) -> dict: - """Average a list of checkpoints. - The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. - - Args: - filenames: - Filenames of the checkpoints to be averaged. We assume all - checkpoints are saved by :func:`save_checkpoint`. - device: - Move checkpoints to this device before averaging. - Returns: - Return a dict (i.e., state_dict) which is the average of all - model state dicts contained in the checkpoints. - """ - n = len(filenames) - - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] - else: - avg = torch.load(filenames[0], map_location=device) - - # Identify shared parameters. Two parameters are said to be shared - # if they have the same data_ptr - uniqued: Dict[int, str] = dict() - - for k, v in avg.items(): - v_data_ptr = v.data_ptr() - if v_data_ptr in uniqued: - continue - uniqued[v_data_ptr] = k - - uniqued_names = list(uniqued.values()) - - for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] - else: - state_dict = torch.load(filenames[i], map_location=device) - for k in uniqued_names: - avg[k] += state_dict[k] - - for k in uniqued_names: - if avg[k].is_floating_point(): - avg[k] /= n - else: - avg[k] //= n - - return avg - - -def remove_punctuation(text: str or List[str]): - """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py - - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings without any punctuation. 
- """ - punctuation = "!,.;:?、!,。;:?《》 " - if isinstance(text, str): - text = re.sub(r"[{}]+".format(punctuation), "", text).strip() - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = re.sub(r"[{}]+".format(punctuation), "", t).strip() - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type {type(text)}") - - -def to_simple(text: str or List[str]): - """Convert traditional Chinese to simplified Chinese. - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings converted to simplified Chinese. - """ - if isinstance(text, str): - text = convert(text, "zh-cn") - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = convert(t, "zh-cn") - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type{type(text)}") - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=-1, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--method", - type=str, - default="beam-search", - help="""Decoding method. - Supported values are: - - beam-search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=1, - help="beam size for beam search decoding", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--remove-whisper-encoder-input-length-restriction", - type=str2bool, - default=True, - help="replace whisper encoder forward method to remove input length restriction", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: "beam-search" - - value: A list of lists. Each sublist is a list of token IDs. - Args: - params: - It is returned by :func:`get_params`. - model: - The neural model. - batch: - It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. - Returns: - Return a dict, whose key may be "beam-search". 
- """ - dtype = torch.float16 - device = torch.device("cuda") - - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device, dtype=dtype).transpose(1, 2) - if not params.remove_whisper_encoder_input_length_restriction: - T = 3000 - if feature.shape[2] < T: - feature = torch.cat( - [ - feature, - torch.zeros( - feature.shape[0], feature.shape[1], T - feature.shape[2] - ).to(device, dtype=dtype), - ], - 2, - ) - - supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) - results = model.decode(feature, params.decoding_options) - hyps = [result.text for result in results] - - hyps = remove_punctuation(hyps) - hyps = to_simple(hyps) - hyps = [params.normalizer.normalize(hyp) for hyp in hyps] - print(hyps) - return {"beam-search": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - The dataloader. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a dict, whose key may be "beam-search". - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - batch=batch, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" - ) - - options = whisper.DecodingOptions( - task="transcribe", - language="zh", - without_timestamps=True, - beam_size=params.beam_size, - ) - params.decoding_options = options - params.cleaner = BasicTextNormalizer() - params.normalizer = Normalizer() - - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda") - - logging.info(f"device: {device}") - - if params.remove_whisper_encoder_input_length_restriction: - replace_whisper_encoder_forward() - model = whisper.load_model(params.model_name, "cpu") - if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - # deepspeed converted checkpoint only contains model state_dict - filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" - for epoch in range(start, params.epoch + 1) - ] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) - else: - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - model.load_state_dict(checkpoint, strict=True) - else: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
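For context, the main() above builds on the standard openai-whisper decoding path (load_model plus DecodingOptions). A minimal standalone sketch of that path outside icefall; "test.wav" is a placeholder file name and "tiny" is chosen only to keep the example light:

import whisper

model = whisper.load_model("tiny")
options = whisper.DecodingOptions(
    task="transcribe", language="zh", without_timestamps=True, beam_size=5
)

audio = whisper.load_audio("test.wav")  # placeholder path
audio = whisper.pad_or_trim(audio)      # fixed 30-second window
mel = whisper.log_mel_spectrogram(audio).to(model.device)

result = whisper.decode(model, mel, options)
print(result.text)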
- args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - dev_cuts = wenetspeech.valid_cuts() - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - def remove_long_utt(c: Cut): - # Keep only utterances with duration in 30 seconds - # - if c.duration > 30.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_long_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - # test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - # test_dls = [dev_dl, test_net_dl, test_meeting_dl] - - test_sets = ["TEST_NET"] - test_dls = [test_net_dl] - - # test_sets = ["TEST_MEETING"] - # test_dls = [test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json deleted file mode 120000 index af7162d6c..000000000 --- a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/ds_config_zero1.json \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/label_smoothing.py b/egs/wenetspeech/ASR/whisper/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/wenetspeech/ASR/whisper/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/optim.py b/egs/wenetspeech/ASR/whisper/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/wenetspeech/ASR/whisper/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/requirements.txt b/egs/wenetspeech/ASR/whisper/requirements.txt deleted file mode 120000 index 744bf8bb6..000000000 --- a/egs/wenetspeech/ASR/whisper/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/requirements.txt \ No newline at end of file diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py deleted file mode 100644 index 4e55fd6a8..000000000 --- a/egs/wenetspeech/ASR/whisper/train.py +++ /dev/null @@ -1,959 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) -# 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -#fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ - --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json - -# fine-tuning with ddp -torchrun --nproc_per_node 8 ./whisper/train.py \ - --max-duration 200 \ - --exp-dir whisper/exp_medium \ - --base-lr 1e-5 \ - --model-name medium -""" - - -import argparse -import copy -import logging -import os -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import deepspeed -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import whisper -from asr_datamodule import WenetSpeechAsrDataModule -from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict -from label_smoothing import LabelSmoothingLoss -from lhotse import CutSet, load_manifest -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.functional import pad as pad_tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import update_averaged_model -from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=10, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="whisper/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--model-name", - type=str, - default="large-v2", - choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], - help="""The model name to use. - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=1e-5, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser = deepspeed.add_config_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - frame_shift_ms: The frame shift in milliseconds. - - allowed_excess_duration_ratio: The allowed excess duration ratio. - - best_train_loss: The best training loss so far. - - best_valid_loss: The best validation loss so far. - - best_train_epoch: The epoch where the best training loss is achieved. - - best_valid_epoch: The epoch where the best validation loss is achieved. - - batch_idx_train: The batch index of the current batch. - - log_interval: Log training stats every `log_interval` batches. - - reset_interval: Reset the stats every `reset_interval` batches. - - valid_interval: Run validation every `valid_interval` batches. - - env_info: The environment information. 
- """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "subsampling_factor": 2, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 10000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute the loss for the given batch. - Args: - params: - It is returned by :func:`get_params`. - tokenizer: - The tokenizer used to encode the text. - model: - The model for training. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - Whether it is training. - Returns: - Return a tuple of two elements. The first element is the loss tensor. - """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - - def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: - padding_size = max(tensor.shape[0] for tensor in tensors) - dims = len(tensors[0].shape) - padded_tensors = [] - for tensor in tensors: - padding = [0] * 2 * dims - padding[-1] = padding_size - tensor.shape[0] - padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) - return torch.stack([tensor for tensor in padded_tensors], dim=0) - - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - - assert feature.ndim == 3 - feature = feature.to(device) - feature = feature.transpose(1, 2) # (N, C, T) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - - texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [text.replace(" ", "") for text in texts] - - text_tokens_list = [ - list(tokenizer.sot_sequence_including_notimestamps) - + tokenizer.encode(text) - + [tokenizer.eot] - for text in texts - ] - # convert it to torch tensor - text_tokens_list = [ - torch.LongTensor(text_tokens) for text_tokens in text_tokens_list - ] - - # 50256 is the index of for all whisper models - prev_outputs_tokens = _batch_tensors( - [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 - ) - target_tokens = _batch_tensors( - [tokens[1:] for tokens in text_tokens_list], pad_value=50256 - ) - target_lengths = torch.LongTensor( - [tokens.shape[0] - 1 for 
tokens in text_tokens_list] - ) - - decoder_criterion = LabelSmoothingLoss( - ignore_index=50256, label_smoothing=0.1, reduction="sum" - ) - - # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> - ignore_prefix_size = 3 - with torch.set_grad_enabled(is_training): - encoder_out = model.encoder(feature) - text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) - text_logits = text_logits[:, ignore_prefix_size:, :] - target_tokens = target_tokens[:, ignore_prefix_size:] - loss = decoder_criterion(text_logits, target_tokens.to(device)) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - tokenizer=tokenizer, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - if params.deepspeed: - model.save_checkpoint( - save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", - client_state={}, - ) - if rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", - tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", - ) - os.system( - f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" - ) - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if params.deepspeed: - # deepspeed's backward() is different from torch's backward() - # in that it does not accept a loss tensor as input. - # It computes the loss internally. - model.backward(loss) - model.step() - else: - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - and not params.deepspeed - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - if batch_idx % params.log_interval == 0: - try: - cur_lr = scheduler.get_last_lr()[0] - except: # noqa - cur_lr = 0.0 - cur_grad_scale = ( - scaler._scale.item() - if (params.use_fp16 and not params.deepspeed) - else 1.0 - ) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if (params.use_fp16 and not params.deepspeed) - else "" - ) - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info(params) - - logging.info("About to create model") - - replace_whisper_encoder_forward() - model = whisper.load_model(params.model_name, "cpu") - del model.alignment_heads - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - tokenizer = whisper.tokenizer.get_tokenizer( - model.is_multilingual, - num_languages=model.num_languages, - language="zh", - task="transcribe", - ) - - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - else: - device = torch.device("cpu") - logging.info(f"Device: {device}") - model.to(device) - - optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if world_size > 1: - if params.deepspeed: - logging.info("Using DeepSpeed") - model, optimizer, _, scheduler = deepspeed.initialize( - args=params, model=model, 
model_parameters=model.parameters() - ) - else: - logging.info("Using DDP") - setup_dist(use_ddp_launch=True) - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - wenetspeech = WenetSpeechAsrDataModule(args) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds - # - # Caution: There is a reason to select 15.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 15.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = wenetspeech.train_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = wenetspeech.train_dataloaders(train_cuts) - valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts()) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - logging.info(f"start training from epoch {params.start_epoch}") - for epoch in range(params.start_epoch, params.num_epochs + 1): - if not params.deepspeed: - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - tokenizer=tokenizer, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if params.deepspeed: - model.save_checkpoint( - save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", - client_state={}, - ) - if rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", - ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") - else: - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1 and not params.deepspeed: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. 
- params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = get_world_size() - rank = get_rank() - - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - run(rank=rank, world_size=world_size, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py deleted file mode 120000 index 2a7808921..000000000 --- a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/__init__.py b/egs/wenetspeech/ASR/zipformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/wenetspeech/ASR/zipformer/asr_datamodule.py b/egs/wenetspeech/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a074d6085..000000000 --- a/egs/wenetspeech/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/beam_search.py b/egs/wenetspeech/ASR/zipformer/beam_search.py deleted file mode 120000 index 8554e44cc..000000000 --- a/egs/wenetspeech/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py deleted file mode 100755 index 0fbc8244b..000000000 --- a/egs/wenetspeech/ASR/zipformer/decode.py +++ /dev/null @@ -1,818 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
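Under DeepSpeed, the deleted whisper/train.py above consolidates each epoch's ZeRO shards into a single fp32 file exp_dir/epoch-N.pt via convert_zero_checkpoint_to_fp32_state_dict; that file is a plain state_dict, which is why the deleted whisper/decode.py checks whether a "model" key is present before loading. A minimal sketch of reloading such a checkpoint, assuming a hypothetical experiment directory and epoch number:

import torch
import whisper

# Recipe-local module (symlinked from the aishell whisper recipe); it lifts the
# 30-second input-length restriction of the stock Whisper encoder.
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

exp_dir = "whisper/exp_large_v2"  # hypothetical --exp-dir
epoch = 10                        # hypothetical epoch to load

replace_whisper_encoder_forward()
model = whisper.load_model("large-v2", "cpu")

state_dict = torch.load(f"{exp_dir}/epoch-{epoch}.pt", map_location="cpu")
# DeepSpeed-converted checkpoints contain only the model weights.
model.load_state_dict(state_dict, strict=True)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()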
-""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. 
The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, 
- ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - dev_cuts = wenetspeech.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dls = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/zipformer/decode_stream.py b/egs/wenetspeech/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/wenetspeech/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/decoder.py b/egs/wenetspeech/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/wenetspeech/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/encoder_interface.py b/egs/wenetspeech/ASR/zipformer/encoder_interface.py deleted file mode 120000 index b9aa0ae08..000000000 --- a/egs/wenetspeech/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py b/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/wenetspeech/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No 
newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/export-onnx.py b/egs/wenetspeech/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/wenetspeech/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/export.py b/egs/wenetspeech/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/wenetspeech/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/jit_pretrained.py b/egs/wenetspeech/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/wenetspeech/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/wenetspeech/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/joiner.py b/egs/wenetspeech/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/wenetspeech/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/model.py b/egs/wenetspeech/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/wenetspeech/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/onnx_check.py b/egs/wenetspeech/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/wenetspeech/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/onnx_decode.py b/egs/wenetspeech/ASR/zipformer/onnx_decode.py deleted file mode 100755 index ed5f6db08..000000000 --- a/egs/wenetspeech/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,334 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. - -We use the pre-trained model from -https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 -as an example to show how to use this file. - -1. 
Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/tokens.txt" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-9999.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --epoch 9999 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx - -2. Run this file - -./zipformer/onnx_decode.py \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ -""" - - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from lhotse.cut import Cut -from onnx_pretrained import OnnxModel, greedy_search - -from icefall.utils import setup_logger, store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: k2.SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - Mapping ids to tokens. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. - """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - hyps = [[token_table[h] for h in hyp] for hyp in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: k2.SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. 
- - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - Mapping ids to tokens. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." - res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(args.tokens) - assert token_table[0] == "" - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - - wenetspeech = WenetSpeechAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." 
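    # Worked example for the filter above (a note on the arithmetic only):
    # num_frames = 23 gives T = ((23 - 7) // 2 + 1) // 2 = 4, while any
    # cut shorter than 9 frames yields T <= 0 and is excluded from decoding.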
- ) - return T > 0 - - dev_cuts = wenetspeech.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dl = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/wenetspeech/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py b/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/wenetspeech/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/optim.py b/egs/wenetspeech/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/wenetspeech/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/pretrained.py b/egs/wenetspeech/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/wenetspeech/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/scaling.py b/egs/wenetspeech/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/wenetspeech/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/scaling_converter.py b/egs/wenetspeech/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/wenetspeech/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py b/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/wenetspeech/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end 
of file diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py deleted file mode 100755 index cb2cf7d35..000000000 --- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,882 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -from asr_datamodule import WenetSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="Path to the lang dir(containing lexicon, tokens, etc.)", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). 
- state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
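    # A minimal sketch of the intended round trip between stack_states()
    # and unstack_states(), using only helpers defined in this file:
    #
    #   states_a = get_init_states(model=model, batch_size=1, device=device)
    #   states_b = get_init_states(model=model, batch_size=1, device=device)
    #   batched = stack_states([states_a, states_b])
    #   restored = unstack_states(batched)
    #   # restored[0] and restored[1] should match states_a and states_b
    #   # tensor by tensor: each cache is chunked back along the same
    #   # batch dimension it was concatenated on.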
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
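    # Example of the left-context mask built in this function: with
    # left_context_len = 4 and processed_lens = [2],
    #
    #   >>> import torch
    #   >>> mask = torch.arange(4).expand(1, 4)
    #   >>> processed_lens = torch.tensor([2])
    #   >>> (processed_lens.unsqueeze(1) <= mask).flip(1)
    #   tensor([[ True,  True, False, False]])
    #
    # i.e. the two cache positions that have not been filled yet are
    # treated as padding, which is what "mask out initial states" below
    # refers to.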
- """ - cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
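    # For example, with --chunk-size 16 each chunk needs at least
    # tail_length = 16 * 2 + 7 + 2 * 3 = 45 feature frames: 32 frames of
    # 10 ms fbank features for a 16-frame (50 Hz) chunk, plus 7 frames of
    # subsampling context and 2 * 3 frames for the ConvNeXt right padding
    # described in the comments below.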
- # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - encoder_out=encoder_out, - streams=decode_streams, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - lexicon: - The Lexicon. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. 
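        # Each iteration builds one DecodeStream for the cut: fresh encoder
        # states from get_init_states(), fbank features computed on the fly
        # with kaldifeat, and the reference text kept as ground truth. Once
        # params.num_decode_streams streams are in flight, decode_one_chunk()
        # is called repeatedly and finished streams are popped off into
        # decode_results.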
- initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - # - this is to avoid sending [-32k,+32k] signal in... - # - some lhotse AudioTransform classes can make the signal - # be out of range [-1, 1], hence the tolerance 10 - assert ( - np.abs(audio).max() <= 10 - ), "Should be normalized to [-1, 1], 10 for tolerance..." - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - list(decode_streams[i].ground_truth.strip()), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], - ) - ) - del decode_streams[i] - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - key = f"greedy_search_{key}" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_{key}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}_{key}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
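        # Note that the stored references and hypotheses are sequences of
        # entries from the character-level token table, so the "WER"
        # computed below for these Chinese test sets behaves like a
        # character error rate.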
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - wenetspeech = WenetSpeechAsrDataModule(args) - - dev_cuts = wenetspeech.valid_cuts() - test_net_cuts = wenetspeech.test_net_cuts() - test_meeting_cuts = wenetspeech.test_meeting_cuts() - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - lexicon=lexicon, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/zipformer/subsampling.py b/egs/wenetspeech/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/wenetspeech/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py deleted file mode 100755 index 25b16f632..000000000 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ /dev/null @@ -1,1350 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --exp-dir zipformer/exp \ - --training-subset L - --lr-epochs 1.5 \ - --max-duration 350 - -# For mix precision training: - -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --training-subset L \ - --lr-epochs 1.5 \ - --max-duration 750 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
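    # For example, with --max-duration 350, --world-size 8 and the
    # default --ref-duration 600, each real batch counts as
    # 350 * 8 / 600 ~= 4.67 reference batches, so after 1000 training
    # batches the adjusted count is roughly 4667.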
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="""Maximum left-contexts for causal training, measured in frames which will - be converted to a number of chunks. 
If splitting into chunks, - chunk left-context frames will be chosen randomly from this list; else not relevant.""", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). 
We will scale the simple loss - with this parameter before adding to the final loss.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. 
- - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. 
When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. 
- scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
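# Concretely (with fp16 enabled): this check runs every 100 batches; a scale
# below 8.0 is doubled at every check, a scale in [8.0, 32.0) is doubled only
# every 400 batches, and larger scales are left to GradScaler's own growth
# logic. Very small scales trigger the warning / hard error handled just below.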
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - wenetspeech = WenetSpeechAsrDataModule(args) - - train_cuts = wenetspeech.train_cuts() - valid_cuts = wenetspeech.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds - # - # Caution: There is a reason to select 15.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 15.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = wenetspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - - if False and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/ASR/zipformer/zipformer.py b/egs/wenetspeech/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/wenetspeech/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/RESULTS.md b/egs/wenetspeech/KWS/RESULTS.md deleted file mode 100644 index 29da3e2e5..000000000 --- a/egs/wenetspeech/KWS/RESULTS.md +++ /dev/null @@ -1,58 +0,0 @@ -# Results - -## zipformer transducer model - -This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. - -The modeling units are partial pinyin (i.e initials and finals) with tone. - -The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test net of wenetspeech (has 23 hours audios). - -We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. - -The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-wenetspeech-20240219.tar.gz). - -Here is the results of a small test set which has 20 commands, we list the results of every commands, for -each metric there are two columns, one for the original model trained on wenetspeech L subset, the other -for the finetune model finetuned on in house commands dataset (has 90 hours audio). - -> You can see that the performance of the original model is very poor, I think the reason is the test commands are all collected from real product scenarios which are very different from the scenarios wenetspeech dataset was collected. After finetuning, the performance improves a lot. 
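To make the table below easier to read, here is a small sketch, using the counts from the "All" row of the original model, of how the Recall and "False alarm (time / hour)" columns follow from the FN / FP counts:

    fn, total_positives = 426, 985    # "FN in positive set", original model
    fp, negative_hours = 7, 23        # false positives on the 23-hour negative set
    recall = (total_positives - fn) / total_positives   # ~0.568 -> 56.8%
    false_alarm_per_hour = fp / negative_hours          # ~0.30  -> 0.3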
- -Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours --- | -- | -- | -- | -- | -- | -- | -- | -- -  | original | finetune | original | finetune | original | finetune | original | finetune -All | 426 / 985 | 40/985 | 56.8% | 95.9% | 7 | 1 | 0.3 | 0.04 -下一个 | 5/50 | 0/50 | 90% | 100% | 3 | 0 | 0.13 | 0 -开灯 | 19/49 | 2/49 | 61.2% | 95.9% | 0 | 0 | 0 | 0 -第一个 | 11/50 | 3/50 | 78% | 94% | 3 | 0 | 0.13 | 0 -声音调到最大 | 39/50 | 7/50 | 22% | 86% | 0 | 0 | 0 | 0 -暂停音乐 | 36/49 | 1/49 | 26.5% | 98% | 0 | 0 | 0 | 0 -暂停播放 | 33/49 | 2/49 | 32.7% | 95.9% | 0 | 0 | 0 | 0 -打开卧室灯 | 33/49 | 1/49 | 32.7% | 98% | 0 | 0 | 0 | 0 -关闭所有灯 | 27/50 | 0/50 | 46% | 100% | 0 | 0 | 0 | 0 -关灯 | 25/48 | 2/48 | 47.9% | 95.8% | 1 | 1 | 0.04 | 0.04 -关闭导航 | 25/48 | 1/48 | 47.9% | 97.9% | 0 | 0 | 0 | 0 -打开蓝牙 | 24/47 | 0/47 | 48.9% | 100% | 0 | 0 | 0 | 0 -下一首歌 | 21/50 | 1/50 | 58% | 98% | 0 | 0 | 0 | 0 -换一首歌 | 19/50 | 5/50 | 62% | 90% | 0 | 0 | 0 | 0 -继续播放 | 19/50 | 2/50 | 62% | 96% | 0 | 0 | 0 | 0 -打开闹钟 | 18/49 | 2/49 | 63.3% | 95.9% | 0 | 0 | 0 | 0 -打开音乐 | 17/49 | 0/49 | 65.3% | 100% | 0 | 0 | 0 | 0 -打开导航 | 17/48 | 0/49 | 64.6% | 100% | 0 | 0 | 0 | 0 -打开电视 | 15/50 | 0/49 | 70% | 100% | 0 | 0 | 0 | 0 -大点声 | 12/50 | 5/50 | 76% | 90% | 0 | 0 | 0 | 0 -小点声 | 11/50 | 6/50 | 78% | 88% | 0 | 0 | 0 | 0 - - -This is the result of large test set, it has more than 100 commands, too many to list the details of each commands, so only an overall result here. We also list the results of two weak up words 小云小云 (only test set)and 你好问问 (both training and test sets). For 你好问问, we have to finetune models, one is finetuned on 你好问问 and our in house commands data, the other finetuned on only 你好问问. Both models perform much better than original model, the one finetuned on only 你好问问 behaves slightly better than the other. - -> 小云小云 test set and 你好问问 training, dev and test sets are available at https://github.com/pkufool/open-commands - -Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours --- | -- | -- | -- | -- | -- | -- | -- | -- -  | original | finetune | original | finetune | original | finetune | original | finetune -large | 2429/4505 | 477 / 4505 | 46.1% | 89.4% | 50 | 41 | 2.17 | 1.78 -小云小云(clean) | 30/100 | 40/100 | 70% | 60% | 0 | 0 | 0 | 0 -小云小云(noisy) | 118/350 | 154/350 | 66.3% | 56% | 0 | 0 | 0 | 0 -你好问问(finetune with all keywords data) | 2236/10641 | 678/10641 | 79% | 93.6% | 0 | 0 | 0 | 0 -你好问问(finetune with only 你好问问) | 2236/10641 | 249/10641 | 79% | 97.7% | 0 | 0 | 0 | 0 diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh deleted file mode 100755 index e52e1a9d1..000000000 --- a/egs/wenetspeech/KWS/prepare.sh +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -nj=15 -stage=0 -stop_stage=100 - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Prepare wewetspeech dataset." - mkdir -p data/fbank - if [ ! -e data/fbank/.wewetspeech.done ]; then - pushd ../ASR - ./prepare.sh --stage 0 --stop-stage 17 - ./prepare.sh --stage 22 --stop-stage 22 - popd - pushd data/fbank - ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) . - ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) . - ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) . - ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/L_split_1000) . - ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/M_split_1000) . - ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/S_split_1000) . - ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/musan_feats) . - popd - pushd data - ln -svf $(realpath ../ASR/data/lang_partial_tone) . - popd - touch data/fbank/.wewetspeech.done - else - log "WenetSpeech dataset already exists, skipping." - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare open commands dataset." - mkdir -p data/fbank - if [ ! -e data/fbank/.cn_speech_commands.done ]; then - pushd data - git clone https://github.com/pkufool/open-commands.git - ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt - ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt - pushd open-commands - ./scripts/prepare.sh --stage 1 --stop-stage 1 - ./scripts/prepare.sh --stage 3 --stop-stage 5 - popd - popd - pushd data/fbank - ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) . - ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) . - ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) . - ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) . - popd - touch data/fbank/.cn_speech_commands.done - else - log "CN speech commands dataset already exists, skipping." 
- fi -fi diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh deleted file mode 100755 index 0af7c1595..000000000 --- a/egs/wenetspeech/KWS/run.sh +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -export CUDA_VISIBLE_DEVICES="0,1,2,3" -export PYTHONPATH=../../../:$PYTHONPATH - -stage=0 -stop_stage=100 - -. shared/parse_options.sh || exit 1 - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Train a model." - if [ ! -e data/fbank/.wenetspeech.done ]; then - log "You need to run the prepare.sh first." - exit -1 - fi - - python ./zipformer/train.py \ - --world-size 4 \ - --exp-dir zipformer/exp \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --num-epochs 18 \ - --lr-epochs 1.5 \ - --use-fp16 1 \ - --start-epoch 1 \ - --training-subset L \ - --pinyin-type partial_with_tone \ - --causal 1 \ - --lang-dir data/lang_partial_tone \ - --max-duration 1000 -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Decode the model." - export CUDA_VISIBLE_DEVICES="0" - for t in small large; do - python ./zipformer/decode.py \ - --epoch 18 \ - --avg 2 \ - --exp-dir ./zipformer/exp \ - --tokens ./data/lang_partial_tone/tokens.txt \ - --pinyin-type partial_with_tone \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --test-set $t \ - --keywords-score 1.5 \ - --keywords-threshold 0.1 \ - --keywords-file ./data/commands_${t}.txt \ - --max-duration 3000 - done -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Export the model." 
- - python ./zipformer/export.py \ - --epoch 18 \ - --avg 2 \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_partial_tone/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 - - python ./zipformer/export-onnx-streaming.py \ - --exp-dir zipformer/exp \ - --tokens data/lang_partial_tone/tokens.txt \ - --epoch 18 \ - --avg 2 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Finetune the model" - - # The following configuration of lr schedule should work well - # You may also tune the following parameters to adjust learning rate schedule - base_lr=0.0005 - lr_epochs=100 - lr_batches=100000 - - # We recommend to start from an averaged model - finetune_ckpt=zipformer/exp/pretrained.pt - - python ./zipformer/finetune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --exp-dir zipformer/exp_finetune \ - --lang-dir ./data/lang_partial_tone \ - --pinyin-type partial_with_tone \ - --use-fp16 1 \ - --use-mux 1 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 \ - --base-lr $base_lr \ - --lr-epochs $lr_epochs \ - --lr-batches $lr_batches \ - --finetune-ckpt $finetune_ckpt \ - --max-duration 1500 -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Decode the finetuned model." - export CUDA_VISIBLE_DEVICES="0" - for t in small large; do - python ./zipformer/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./zipformer/exp_finetune \ - --tokens ./data/lang_partial_tone/tokens.txt \ - --pinyin-type partial_with_tone \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --test-set $t \ - --keywords-score 0.000001 \ - --keywords-threshold 0.35 \ - --keywords-file ./data/commands_${t}.txt \ - --max-duration 3000 - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Export the finetuned model." 
- - python ./zipformer/export.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./zipformer/exp_finetune \ - --tokens data/lang_partial_tone/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 64 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 - - python ./zipformer/export-onnx-streaming.py \ - --exp-dir zipformer/exp_finetune \ - --tokens data/lang_partial_tone/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoder-dim 320 \ - --joiner-dim 320 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 192,192,192,192,192,192 \ - --encoder-dim 128,128,128,128,128,128 \ - --encoder-unmasked-dim 128,128,128,128,128,128 \ - --causal 1 -fi diff --git a/egs/wenetspeech/KWS/shared b/egs/wenetspeech/KWS/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/wenetspeech/KWS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py deleted file mode 100644 index 7de748c8e..000000000 --- a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class WenetSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. 
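    A minimal usage sketch (the flag values are illustrative and assume the
    fbank manifests have already been prepared under --manifest-dir):

        parser = argparse.ArgumentParser()
        WenetSpeechAsrDataModule.add_arguments(parser)
        args = parser.parse_args(
            ["--manifest-dir", "data/fbank", "--max-duration", "200"]
        )
        data_module = WenetSpeechAsrDataModule(args)
        train_dl = data_module.train_dataloaders(data_module.train_cuts())
        valid_dl = data_module.valid_dataloaders(data_module.valid_cuts())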
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--training-subset", - type=str, - default="L", - help="The training subset for using", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=300000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - valid_dl = DataLoader( - validate, - batch_size=None, - sampler=valid_sampler, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") - - @lru_cache() - def test_net_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_NET cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") - - @lru_cache() - def test_meeting_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_MEETING cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") - - @lru_cache() - def cn_speech_commands_small_cuts(self) -> CutSet: - logging.info("About to get cn speech commands small cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cn_speech_commands_cuts_small.jsonl.gz" - ) - - @lru_cache() - def cn_speech_commands_large_cuts(self) -> CutSet: - logging.info("About to get cn speech commands large cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cn_speech_commands_cuts_large.jsonl.gz" - ) - - @lru_cache() - def nihaowenwen_dev_cuts(self) -> CutSet: - logging.info("About to get nihaowenwen dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "nihaowenwen_cuts_dev.jsonl.gz" - ) - - @lru_cache() - def nihaowenwen_test_cuts(self) -> CutSet: - logging.info("About to get nihaowenwen test cuts") - return load_manifest_lazy( - self.args.manifest_dir / 
"nihaowenwen_cuts_test.jsonl.gz" - ) - - @lru_cache() - def nihaowenwen_train_cuts(self) -> CutSet: - logging.info("About to get nihaowenwen train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "nihaowenwen_cuts_train.jsonl.gz" - ) - - @lru_cache() - def xiaoyun_clean_cuts(self) -> CutSet: - logging.info("About to get xiaoyun clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "xiaoyun_cuts_clean.jsonl.gz" - ) - - @lru_cache() - def xiaoyun_noisy_cuts(self) -> CutSet: - logging.info("About to get xiaoyun noisy cuts") - return load_manifest_lazy( - self.args.manifest_dir / "xiaoyun_cuts_noisy.jsonl.gz" - ) diff --git a/egs/wenetspeech/KWS/zipformer/beam_search.py b/egs/wenetspeech/KWS/zipformer/beam_search.py deleted file mode 120000 index 94033eebf..000000000 --- a/egs/wenetspeech/KWS/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py deleted file mode 100755 index 6425030eb..000000000 --- a/egs/wenetspeech/KWS/zipformer/decode-asr.py +++ /dev/null @@ -1,767 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) fast beam search (LG) -./zipformer/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. 
- Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. 
- pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, - ) - for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), - nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, - beam=params.beam_size, - ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - blank_penalty=params.blank_penalty, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) - - key = f"blank_penalty_{params.blank_penalty}" - if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - 
key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}_" + key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
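# write_error_stats returns the WER for each decoding setting; the settings
# are then sorted by WER so the summary file lists the best setting first.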
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "modified_beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest_oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: 
- raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." 
- ) - return T > 0 - - dev_cuts = wenetspeech.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dls = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py deleted file mode 100755 index a628c7e58..000000000 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ /dev/null @@ -1,734 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import math -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict, List, Optional, Set, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import keywords_search -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params - -from icefall import ContextGraph -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - num_tokens, - setup_logger, - store_transcripts, - str2bool, - text_to_pinyin, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -@dataclass -class KwMetric: - TP: int = 0 # True positive - FN: int = 0 # False negative - FP: int = 0 # False positive - TN: int = 0 # True negative - FN_list: List[str] = field(default_factory=list) - FP_list: List[str] = field(default_factory=list) - TP_list: List[str] = field(default_factory=list) - - def __str__(self) -> str: - return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/lang_partial_tone/tokens.txt", - help="The path to the token.txt", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - parser.add_argument( - "--pinyin-type", - type=str, - help="The type of pinyin used as the modeling units.", - ) - - parser.add_argument( - "--keywords-file", - type=str, - help="File contains keywords.", - ) - - parser.add_argument( - "--test-set", - type=str, - default="small", - help="small or large", - ) - - parser.add_argument( - "--keywords-score", - type=float, - default=1.5, - help=""" - The default boosting score (token level) for keywords. it will boost the - paths that match keywords to make them survive beam search. - """, - ) - - parser.add_argument( - "--keywords-threshold", - type=float, - default=0.35, - help="The default threshold (probability) to trigger the keyword.", - ) - - parser.add_argument( - "--num-tailing-blanks", - type=int, - default=1, - help="The number of tailing blanks should have after hitting one keyword.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, - keywords_graph: ContextGraph, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. 
- batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - ans_dict = keywords_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - keywords_graph=keywords_graph, - beam=params.beam_size, - num_tailing_blanks=8, - ) - - hyps = [] - for ans in ans_dict: - hyp = [] - for hit in ans: - hyp.append( - ( - hit.phrase, - (hit.timestamps[0], hit.timestamps[-1]), - ) - ) - hyps.append(hyp) - - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - keywords_graph: ContextGraph, - keywords: Set[str], - test_only_keywords: bool, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
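# Shape note for the loop below (illustrative values, not from a real run):
# decode_one_batch() returns one list per utterance, and each element of that
# list is a (phrase, (first_timestamp, last_timestamp)) tuple for a triggered
# keyword, e.g. for a batch of two cuts:
#
#     [
#         [("你好问问", (23, 41))],  # one keyword hit in the first cut
#         [],                        # no keyword triggered in the second cut
#     ]
#
# The loop below turns these hits into pseudo-transcripts and per-keyword
# TP/FP/FN/TN counts.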
- - log_interval = 20 - - results = [] - metric = {"all": KwMetric()} - for k in keywords: - metric[k] = KwMetric() - - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - keywords_graph=keywords_graph, - batch=batch, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text) - hyp_words = [x[0] for x in hyp_words] - this_batch.append((cut_id, ref_words, list("".join(hyp_words)))) - hyp_set = set(hyp_words) - if len(hyp_words) > 1: - logging.warning( - f"Cut {cut_id} triggers more than one keywords : {hyp_words}," - f"please check the transcript to see if it really has more " - f"than one keywords, if so consider splitting this audio and" - f"keep only one keyword for each audio." - ) - hyp_str = " | ".join( - hyp_words - ) # The triggered keywords for this utterance. - TP = False - FP = False - for x in hyp_set: - assert x in keywords, x # can only trigger keywords - if (test_only_keywords and x == ref_text) or ( - not test_only_keywords and x in ref_text - ): - TP = True - metric[x].TP += 1 - metric[x].TP_list.append(f"({ref_text} -> {x})") - if (test_only_keywords and x != ref_text) or ( - not test_only_keywords and x not in ref_text - ): - FP = True - metric[x].FP += 1 - metric[x].FP_list.append(f"({ref_text} -> {x})") - if TP: - metric["all"].TP += 1 - if FP: - metric["all"].FP += 1 - TN = True # all keywords are true negative then the summery is true negative. - FN = False - for x in keywords: - if x not in ref_text and x not in hyp_set: - metric[x].TN += 1 - continue - - TN = False - if (test_only_keywords and x == ref_text) or ( - not test_only_keywords and x in ref_text - ): - fn = True - for y in hyp_set: - if (test_only_keywords and y == ref_text) or ( - not test_only_keywords and y in ref_text - ): - fn = False - break - if fn: - FN = True - metric[x].FN += 1 - metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") - if TN: - metric["all"].TN += 1 - if FN: - metric["all"].FN += 1 - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results, metric - - -def save_results( - params: AttributeDict, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], - metric: KwMetric, -): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" - - with open(metric_filename, "w") as of: - width = 10 - for key, item in sorted( - metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True - ): - acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) - precision = ( - 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) - ) - recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) - fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) - s = f"{key}:\n" - s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" - s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" - s += f"\tAccuracy: {acc:.3f}\n" - s += f"\tPrecision: {precision:.3f}\n" - s += f"\tRecall(PPR): {recall:.3f}\n" - s += f"\tFPR: {fpr:.3f}\n" - s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" - if key != "all": - s += f"\tTP list: {' # '.join(item.TP_list)}\n" - s += f"\tFP list: {' # '.join(item.FP_list)}\n" - s += f"\tFN list: {' # '.join(item.FN_list)}\n" - of.write(s + "\n") - if key == "all": - logging.info(s) - of.write(f"\n\n{params.keywords_config}") - - logging.info("Wrote metric stats to {}".format(metric_filename)) - - -@torch.no_grad() -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "kws" - - params.suffix = params.test_set - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - params.suffix += f"-score-{params.keywords_score}" - params.suffix += f"-threshold-{params.keywords_threshold}" - params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" - if params.blank_penalty != 0: - params.suffix += f"-blank-penalty-{params.blank_penalty}" - params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - phrases = [] - token_ids = [] - keywords_scores = [] - keywords_thresholds = [] - keywords_config = [] - with open(params.keywords_file, "r") as f: - for line in f.readlines(): - keywords_config.append(line) - score = 0 - threshold = 0 - keyword = [] - words = line.strip().upper().split() - for word in words: - word = word.strip() - if word[0] == ":": - score = float(word[1:]) - continue - if word[0] == "#": - threshold = float(word[1:]) - continue - keyword.append(word) - keyword = "".join(keyword) - tmp_ids = [] - kws_py = text_to_pinyin(keyword, mode=params.pinyin_type) - for k in kws_py: - if k in token_table: - tmp_ids.append(token_table[k]) - else: - logging.warning(f"Containing OOV tokens, skipping line : {line}") - tmp_ids = [] - break - if tmp_ids: - logging.info(f"Adding keyword : {keyword}") - phrases.append(keyword) - token_ids.append(tmp_ids) - keywords_scores.append(score) - keywords_thresholds.append(threshold) - params.keywords_config = "".join(keywords_config) - - keywords_graph = ContextGraph( - context_score=params.keywords_score, ac_threshold=params.keywords_threshold - ) - keywords_graph.build( - token_ids=token_ids, - phrases=phrases, - scores=keywords_scores, - ac_thresholds=keywords_thresholds, - ) - keywords = set(phrases) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" 
--iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - cn_commands_small_cuts = wenetspeech.cn_speech_commands_small_cuts() - cn_commands_small_cuts = cn_commands_small_cuts.filter(remove_short_utt) - cn_commands_small_dl = wenetspeech.test_dataloaders(cn_commands_small_cuts) - - cn_commands_large_cuts = wenetspeech.cn_speech_commands_large_cuts() - cn_commands_large_cuts = cn_commands_large_cuts.filter(remove_short_utt) - cn_commands_large_dl = wenetspeech.test_dataloaders(cn_commands_large_cuts) - - nihaowenwen_test_cuts = wenetspeech.nihaowenwen_test_cuts() - nihaowenwen_test_cuts = nihaowenwen_test_cuts.filter(remove_short_utt) - nihaowenwen_test_dl = wenetspeech.test_dataloaders(nihaowenwen_test_cuts) - - xiaoyun_clean_cuts = wenetspeech.xiaoyun_clean_cuts() - xiaoyun_clean_cuts = xiaoyun_clean_cuts.filter(remove_short_utt) - xiaoyun_clean_dl = wenetspeech.test_dataloaders(xiaoyun_clean_cuts) - - xiaoyun_noisy_cuts = wenetspeech.xiaoyun_noisy_cuts() - xiaoyun_noisy_cuts = xiaoyun_noisy_cuts.filter(remove_short_utt) - xiaoyun_noisy_dl = wenetspeech.test_dataloaders(xiaoyun_noisy_cuts) - - test_sets = [] - test_dls = [] - if params.test_set == "large": - test_sets += ["cn_commands_large", "test_net"] - test_dls += [cn_commands_large_dl, test_net_dl] - else: - assert params.test_set == "small", params.test_set - test_sets += [ - "cn_commands_small", - "nihaowenwen", - "xiaoyun_clean", - "xiaoyun_noisy", - "test_net", - ] - test_dls += [ - cn_commands_small_dl, - nihaowenwen_test_dl, - xiaoyun_clean_dl, - xiaoyun_noisy_dl, - test_net_dl, - ] - - for test_set, test_dl in zip(test_sets, test_dls): - results, metric = decode_dataset( - dl=test_dl, - params=params, - model=model, - keywords_graph=keywords_graph, - keywords=keywords, - test_only_keywords="test_net" not in test_set, - ) - - save_results( - params=params, - test_set_name=test_set, - results=results, - metric=metric, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git 
a/egs/wenetspeech/KWS/zipformer/decoder.py b/egs/wenetspeech/KWS/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/wenetspeech/KWS/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/encoder_interface.py b/egs/wenetspeech/KWS/zipformer/encoder_interface.py deleted file mode 120000 index 2c56d3d18..000000000 --- a/egs/wenetspeech/KWS/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py b/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/export.py b/egs/wenetspeech/KWS/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/wenetspeech/KWS/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py deleted file mode 100755 index b1abfd79e..000000000 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ /dev/null @@ -1,764 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Yifan Yang, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model finetuning: -./zipformer/finetune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For non-streaming model finetuning with mux (original dataset): -./zipformer/finetune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --use-mux 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model finetuning: -./zipformer/fintune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -# For streaming model finetuning with mux (original dataset): -./zipformer/fintune.py \ - --world-size 4 \ - --num-epochs 10 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 -""" - - -import argparse -import copy -import logging -import warnings -from functools import partial -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from lhotse.cut import Cut, CutSet -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from train import ( - add_model_arguments, - add_training_arguments, - compute_validation_loss, - display_and_save_batch, - encode_text, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, -) - -from icefall import diagnostics -from icefall.checkpoint import remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - num_tokens, - setup_logger, - str2bool, - text_to_pinyin, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--use-mux", - type=str2bool, - default=False, - help=""" - Whether to adapt. If true, we will mix 5% of the new data - with 95% of the original data to fine-tune. - """, - ) - - parser.add_argument( - "--init-modules", - type=str, - default=None, - help=""" - Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma seperated. If None, - all modules will be initialised. For example, if you only want to - initialise all parameters staring with "encoder", use "encoder"; - if you want to initialise parameters starting with encoder or decoder, - use "encoder,joiner". 
- """, - ) - - parser.add_argument( - "--finetune-ckpt", - type=str, - default=None, - help="Fine-tuning from which checkpoint (a path to a .pt file)", - ) - - parser.add_argument( - "--continue-finetune", - type=str2bool, - default=False, - help="Continue finetuning or finetune from pre-trained model", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_partial_tone", - help="Path to the pinyin lang directory", - ) - - parser.add_argument( - "--pinyin-type", - type=str, - default="partial_with_tone", - help=""" - The style of the output pinyin, should be: - full_with_tone : zhōng guó - full_no_tone : zhong guo - partial_with_tone : zh ōng g uó - partial_no_tone : zh ong g uo - """, - ) - - parser.add_argument( - "--pinyin-errors", - default="split", - type=str, - help="""How to handle characters that has no pinyin, - see `text_to_pinyin` in icefall/utils.py for details - """, - ) - - add_training_arguments(parser) - add_model_arguments(parser) - add_finetune_arguments(parser) - - return parser - - -def load_model_params( - ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True -): - """Load model params from checkpoint - - Args: - ckpt (str): Path to the checkpoint - model (nn.Module): model to be loaded - - """ - logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") - - # if module list is empty, load the whole model from ckpt - if not init_modules: - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - else: - src_state_dict = checkpoint["model"] - dst_state_dict = model.state_dict() - for module in init_modules: - logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [ - k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") - ] - dst_keys = [ - k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") - ] - assert set(src_keys) == set(dst_keys) # two sets should match exactly - for key in src_keys: - dst_state_dict[key] = src_state_dict.pop(key) - - model.load_state_dict(dst_state_dict, strict=strict) - - return None - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = [c.supervisions[0].tokens for c in supervisions["cut"]] - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss, ctc_loss = losses[:3] - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params) + 100000) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - - if params.continue_finetune: - assert params.start_epoch > 0, params.start_epoch - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - else: - modules = params.init_modules.split(",") if params.init_modules else None - checkpoints = load_model_params( - ckpt=params.finetune_ckpt, model=model, init_modules=modules - ) - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_utt(c: Cut): - if c.duration > 15: - return False - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - return T > 0 - - wenetspeech = WenetSpeechAsrDataModule(args) - - if params.use_mux: - train_cuts = CutSet.mux( - wenetspeech.train_cuts(), - wenetspeech.nihaowenwen_train_cuts(), - weights=[0.9, 0.1], - ) - else: - train_cuts = wenetspeech.nihaowenwen_train_cuts() - - _encode_text = partial(encode_text, token_table=token_table, params=params) - - train_cuts = train_cuts.filter(remove_short_utt) - train_cuts = train_cuts.map(_encode_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - 
sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = wenetspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = wenetspeech.nihaowenwen_dev_cuts() - valid_cuts = valid_cuts.filter(remove_short_utt) - valid_cuts = valid_cuts.map(_encode_text) - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and params.scan_for_oom_batches: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.return_cuts = True - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/wenetspeech/KWS/zipformer/joiner.py b/egs/wenetspeech/KWS/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/wenetspeech/KWS/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/model.py b/egs/wenetspeech/KWS/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/wenetspeech/KWS/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/optim.py b/egs/wenetspeech/KWS/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/wenetspeech/KWS/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/scaling.py b/egs/wenetspeech/KWS/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/wenetspeech/KWS/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/scaling_converter.py b/egs/wenetspeech/KWS/zipformer/scaling_converter.py deleted file mode 120000 index 
b0ecee05e..000000000 --- a/egs/wenetspeech/KWS/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/subsampling.py b/egs/wenetspeech/KWS/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/wenetspeech/KWS/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py deleted file mode 100755 index 5d9d8de36..000000000 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ /dev/null @@ -1,1391 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from functools import partial -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from 
icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - num_tokens, - setup_logger, - str2bool, - text_to_pinyin, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. 
- """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="""Maximum left-contexts for causal training, measured in frames which will - be converted to a number of chunks. If splitting into chunks, - chunk left-context frames will be chosen randomly from this list; else not relevant.""", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def add_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="""The context size in the decoder. 
1 means bigram; 2 means tri-gram""", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). We will scale the simple loss - with this parameter before adding to the final loss.""", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--scan-for-oom-batches", - type=str2bool, - default=False, - help=""" - Whether to scan for oom batches before training, this is helpful for - finding the suitable max_duration, you only need to run it once. - Caution: a little time consuming. - """, - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_partial_tone", - help="Path to the pinyin lang directory", - ) - - parser.add_argument( - "--pinyin-type", - type=str, - default="partial_with_tone", - help=""" - The style of the output pinyin, should be: - full_with_tone : zhōng guó - full_no_tone : zhong guo - partial_with_tone : zh ōng g uó - partial_no_tone : zh ong g uo - """, - ) - - parser.add_argument( - "--pinyin-errors", - default="split", - type=str, - help="""How to handle characters that has no pinyin, - see `text_to_pinyin` in icefall/utils.py for details - """, - ) - - add_training_arguments(parser) - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. 
- encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
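The resume logic just above chooses which checkpoint to load from `--start-batch` and `--start-epoch`. A standalone restatement of that rule as a pure function, with hypothetical values in the usage lines:

```
# Pure-function restatement of the checkpoint-selection rule used above.
# The example values below are hypothetical.
from pathlib import Path
from typing import Optional

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    if start_batch > 0:
        # --start-batch takes precedence over --start-epoch
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        # epoch N resumes from the checkpoint written at the end of epoch N - 1
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run, nothing to load

print(resume_filename(Path("zipformer/exp"), start_batch=0, start_epoch=4))
# zipformer/exp/epoch-3.pt
print(resume_filename(Path("zipformer/exp"), start_batch=8000, start_epoch=4))
# zipformer/exp/checkpoint-8000.pt
```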
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = [c.supervisions[0].tokens for c in supervisions["cut"]] - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - losses = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss, pruned_loss = losses[:2] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. 
If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
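The comment above explains why the grad scale is grown by hand when it stays small; the code that follows doubles it when it is below 8, or below 32 on every 400th batch (the whole check runs every 100 batches and only with --use-fp16). A pure-predicate restatement of that policy, for readability only:

```
# Restatement of the manual grad-scale growth policy implemented just below
# (same thresholds; a readability aid, not a replacement for the real code).
def should_double_grad_scale(cur_grad_scale: float, batch_idx: int) -> bool:
    # The caller only reaches this check every 100 batches and when fp16 is on.
    if cur_grad_scale < 8.0:
        return True
    if cur_grad_scale < 32.0 and batch_idx % 400 == 0:
        return True
    return False

print(should_double_grad_scale(4.0, 100))   # True: always grow while below 8
print(should_double_grad_scale(16.0, 100))  # False: below 32, but not a 400-batch boundary
print(should_double_grad_scale(16.0, 400))  # True
print(should_double_grad_scale(64.0, 400))  # False: already large enough
```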
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict): - text = c.supervisions[0].text - tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ids = [] - for t in tokens: - if t in token_table: - ids.append(token_table[t]) - else: - logging.warning(f"Text : {text} has OOV token : {t} , encode to ") - ids.append(token_table[""]) - c.supervisions[0].tokens = ids - return c - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - wenetspeech = WenetSpeechAsrDataModule(args) - - train_cuts = wenetspeech.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 15.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" - # ) - return False - - return True - - _encode_text = partial(encode_text, token_table=token_table, params=params) - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(_encode_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = wenetspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = wenetspeech.valid_cuts() - valid_cuts = valid_cuts.map(_encode_text) - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics and params.scan_for_oom_batches: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = supervisions["text"] - tokens = [c.supervisions[0].tokens for c in supervisions["cut"]] - num_tokens = sum(len(i) for i in tokens) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - args.return_cuts = True - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/wenetspeech/KWS/zipformer/zipformer.py b/egs/wenetspeech/KWS/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/wenetspeech/KWS/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md deleted file mode 100644 index 8329ae948..000000000 --- a/egs/wenetspeech4tts/TTS/README.md +++ /dev/null @@ -1,188 +0,0 @@ -# Results -| Model | Seed-TTS test_zh CER | Comment | -|---------------------------------------|---------------------|--------| -| [vall-e](./valle) | 4.33% | ~150M | -| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M | -| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M | - -# Introduction - -[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. - -> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). -> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> -> By using this framework, you agree to the following: -> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> -> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. 
We explicitly disclaim any liability for misuse of the technology. -> -> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> -> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. - - -# [VALL-E](https://arxiv.org/abs/2301.02111) - -./valle contains the code for training VALL-E TTS model. - -Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). The demo of the model trained with Wenetspeech4TTS Premium (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo). - -Preparation: - -``` -bash prepare.sh -``` - -The training command is given below: - -``` -world_size=8 -exp_dir=exp/valle - -## Train AR model -python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ - --exp-dir ${exp_dir} --world-size ${world_size} - -## Train NAR model -# cd ${exp_dir} -# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 -# cd - -python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ - --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ - --exp-dir ${exp_dir} --world-size ${world_size} -``` - -To inference, use: -``` -huggingface-cli login -huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts -top_p=1.0 -python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ - --top-k -1 --temperature 1.0 \ - --text ./aishell3.txt \ - --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ - --text-extractor pypinyin_initials_finals --top-p ${top_p} -``` - -# [F5-TTS](https://arxiv.org/abs/2410.06885) - -./f5-tts contains the code for training F5-TTS model. - -Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard). - -Preparation: - -``` -bash prepare.sh --stage 5 --stop_stage 6 -``` -(Note: To compatiable with F5-TTS official checkpoint, we direclty use `vocab.txt` from [here.](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt) To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).) 
- -The training command is given below: - -``` -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece - -world_size=8 -exp_dir=exp/f5-tts-small -python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ - --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ - --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ - --exp-dir ${exp_dir} --world-size ${world_size} -``` - -To inference with Icefall Wenetspeech4TTS trained F5-Small, use: -``` -huggingface-cli login -huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset -huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic -huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x - -manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst -model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt -# skip -python3 f5-tts/generate_averaged_model.py \ - --epoch 56 \ - --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ - --exp-dir exp/f5_small - - -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 -bash local/compute_wer.sh $output_dir $manifest -``` - -To inference with official Emilia trained F5-Base, use: -``` -huggingface-cli login -huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset -huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS -huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x - -manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst -model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt - -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir -bash local/compute_wer.sh $output_dir $manifest -``` - -# F5-TTS-Semantic-Token - -./f5-tts contains the code for training F5-TTS-Semantic-Token. We replaced the text tokens in F5-TTS with pretrained cosyvoice2 semantic tokens. During inference, we use the pretrained CosyVoice2 LLM to predict the semantic tokens for target audios. We observed that this approach leads to faster convergence and improved prosody modeling results. - -Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main). 
- -Preparation: - -``` -# extract cosyvoice2 semantic tokens -bash prepare.sh --stage 5 --stop_stage 7 -``` - -The training command is given below: - -``` -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece - -world_size=8 -exp_dir=exp/f5-tts-semantic-token-small -python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ - --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \ - --num-epochs 10 --start-epoch 1 --start-batch 0 \ - --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ - --exp-dir ${exp_dir} --world-size ${world_size} \ - --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True -``` - -To inference with Icefall Wenetspeech4TTS trained F5-Small-Semantic-Token, use: -``` -huggingface-cli login -huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic -huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x - -split=test_zh -model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt - -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True -bash local/compute_wer.sh $output_dir $manifest -``` - -# Credits -- [VALL-E](https://github.com/lifeiteng/vall-e) -- [F5-TTS](https://github.com/SWivid/F5-TTS) -- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py b/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py deleted file mode 100644 index f02358553..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# Copyright 2024 Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -python3 bin/generate_averaged_model.py \ - --epoch 40 \ - --avg 5 \ - --exp-dir ${exp_dir} - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. 
-""" - - -import argparse -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, -) -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - add_model_arguments(parser) - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = AttributeDict() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - print("About to create model") - filename = f"{params.exp_dir}/epoch-{params.epoch}.pt" - checkpoint = torch.load(filename, map_location=device) - args = AttributeDict(checkpoint) - model = get_model(args) - - if params.iter > 0: - # TODO FIX ME - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - filenames = [ - f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1) - ] - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - checkpoint["model"] = model.state_dict() - torch.save(checkpoint, filename) - - num_param = sum([p.numel() for p in 
model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py deleted file mode 100644 index 6964a43be..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ /dev/null @@ -1,828 +0,0 @@ -#!/usr/bin/env python3 -# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py -""" -Usage: -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx -# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x -manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst -python3 f5-tts/generate_averaged_model.py \ - --epoch 56 \ - --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ - --exp-dir exp/f5_small - -# command for text token input -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 - -# command for cosyvoice semantic token input -split=test_zh # seed_tts_eval test_zh -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True - -bash local/compute_wer.sh $output_dir $manifest -""" -import argparse -import logging -import math -import os -import random -import time -from pathlib import Path - -import datasets -import torch -import torch.nn.functional as F -import torchaudio -from accelerate import Accelerator -from bigvganinference import BigVGANInference -from model.cfm import CFM -from model.dit import DiT -from model.modules import MelSpec -from model.utils import convert_char_to_pinyin -from tqdm import tqdm -from train import ( - add_model_arguments, - get_model, - get_tokenizer, - interpolate_tokens, - load_F5_TTS_pretrained_checkpoint, -) - -from icefall.checkpoint import load_checkpoint -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tokens", - type=str, - default="f5-tts/vocab.txt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--model-path", - type=str, - default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--seed", - type=int, - default=0, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--nfe", - type=int, - default=16, - help="The number of steps for the neural ODE", - ) - - parser.add_argument( - "--manifest-file", - type=str, - default=None, - help="The manifest file in seed_tts_eval format", - ) - - parser.add_argument( - "--output-dir", - type=Path, - default="results", - help="The output directory to save the generated wavs", - ) - - parser.add_argument("-ss", "--swaysampling", default=-1, type=float) - - parser.add_argument( - "--interpolate-token", - type=str2bool, - default=True, - help="Interpolate semantic token to match mel frames for CosyVoice", - ) - - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice 
semantic token to replace text token.", - ) - - parser.add_argument( - "--split-name", - type=str, - default="wenetspeech4tts", - choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], - help="huggingface dataset split name", - ) - - add_model_arguments(parser) - return parser.parse_args() - - -def get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio, ref_sr = torchaudio.load(prompt_wav) - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) - if ref_rms < target_rms: - ref_audio = ref_audio * target_rms / ref_rms - assert ( - ref_audio.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio) - - # Text - if len(prompt_text[-1].encode("utf-8")) == 1: - prompt_text = prompt_text + " " - text = [prompt_text + gt_text] - if tokenizer == "pinyin": - text_list = convert_char_to_pinyin(text, polyphone=polyphone) - else: - text_list = text - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 
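When the ground-truth duration is not used, the code above estimates the total mel length by scaling the reference mel length with the byte-length ratio of generated to reference text. A small numeric sketch of that rule, with made-up lengths:

```
# Numeric sketch of the duration estimate used above (made-up lengths).
ref_mel_len = 400    # mel frames of the reference/prompt audio
ref_text_len = 30    # UTF-8 byte length of the prompt text
gen_text_len = 60    # UTF-8 byte length of the text to generate
speed = 1.0

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
print(total_mel_len)  # 1200: the prompt frames plus about twice as many generated frames
```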
- bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - final_text_list[bucket_i].extend(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for i in range(len(dataset)): - utt = dataset[i]["id"] - ref_audio_org, ref_sr = ( - dataset[i]["prompt_audio"]["array"], - dataset[i]["prompt_audio"]["sampling_rate"], - ) - ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() - audio_tokens = dataset[i]["target_audio_cosy2_tokens"] - prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - input_tokens = prompt_audio_tokens + audio_tokens - - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
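The `/ 4 * 15` conversion above turns a semantic-token count into an expected mel-frame count. A quick check of where that ratio comes from, assuming a 25 Hz CosyVoice2 token rate (an assumption; the rate is not stated in this file) and the mel settings used here (24 kHz sample rate, hop 256):

```
# Quick check of the token-to-mel-frame ratio used above.
# Assumption (not stated in this file): CosyVoice2 semantic tokens arrive at 25 Hz.
target_sample_rate = 24000
hop_length = 256
token_rate_hz = 25  # assumed

mel_frame_rate = target_sample_rate / hop_length  # 93.75 mel frames per second
print(mel_frame_rate / token_rate_hz)             # 3.75 == 15 / 4

num_tokens = 100                                  # ~4 seconds of tokens at 25 Hz
print(int(num_tokens / 4 * 15))                   # 375 mel frames, i.e. ~4 seconds of mel
```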
- if total_mel_len > max_tokens: - print( - f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - ) - continue - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def inference_speech_token( - cosyvoice, - tts_text, - prompt_text, - prompt_speech_16k, - stream=False, - speed=1.0, - text_frontend=True, -): - tokens = [] - prompt_text = cosyvoice.frontend.text_normalize( - prompt_text, split=False, text_frontend=text_frontend - ) - for i in cosyvoice.frontend.text_normalize( - tts_text, split=True, text_frontend=text_frontend - ): - - tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) - ( - prompt_text_token, - prompt_text_token_len, - ) = cosyvoice.frontend._extract_text_token(prompt_text) - speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( - prompt_speech_16k - ) - - for i in cosyvoice.model.llm.inference( - text=tts_text_token.to(cosyvoice.model.device), - text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( - cosyvoice.model.device - ), - prompt_text=prompt_text_token.to(cosyvoice.model.device), - prompt_text_len=torch.tensor( - [prompt_text_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - prompt_speech_token=speech_token.to(cosyvoice.model.device), - prompt_speech_token_len=torch.tensor( - [speech_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - embedding=None, - ): - tokens.append(i) - return tokens, speech_token - - -def get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - - import sys - - # please change the path to the cosyvoice accordingly - sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") - sys.path.append("/workspace/CosyVoice") - 
from cosyvoice.cli.cosyvoice import CosyVoice2 - - # please download the cosyvoice model first - cosyvoice = CosyVoice2( - "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False - ) - - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio_org, ref_sr = torchaudio.load(prompt_wav) - - # cosy voice - if ref_sr != 16000: - resampler = torchaudio.transforms.Resample(ref_sr, 16000) - ref_audio_16k = resampler(ref_audio_org) - else: - ref_audio_16k = ref_audio_org - audio_tokens, prompt_audio_tokens = inference_speech_token( - cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False - ) - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - assert ( - ref_audio_org.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - - # Text - # if len(prompt_text[-1].encode("utf-8")) == 1: - # prompt_text = prompt_text + " " - # text = [prompt_text + gt_text] - # if tokenizer == "pinyin": - # text_list = convert_char_to_pinyin(text, polyphone=polyphone) - # else: - # text_list = text - - # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens - # prompt_audio_tokens shape 1, prompt_audio_tokens - # audio_tokens shape 1, audio_tokens - prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() - input_tokens = prompt_audio_tokens + audio_tokens - - # convert it into a list - # input_tokens_list = input_tokens.squeeze().cpu().tolist() - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len_compute = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - print( - f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
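The `target_rms` handling above (and the matching rescale in `main` further below) boosts quiet prompts to a reference loudness before synthesis and then restores the original level on the generated waveform. A small illustrative sketch of that round trip, using made-up tensors:

import torch

def rms(wav: torch.Tensor) -> torch.Tensor:
    # Root-mean-square level of a waveform tensor, as computed in the recipe.
    return torch.sqrt(torch.mean(torch.square(wav)))

target_rms = 0.1
prompt = 0.02 * torch.randn(1, 24000)              # a deliberately quiet 1 s prompt
ref_rms = rms(prompt)
if ref_rms < target_rms:
    prompt = prompt * target_rms / ref_rms         # boost the prompt for inference

generated = torch.randn(1, 24000)                  # stand-in for the vocoder output
if ref_rms < target_rms:
    generated = generated * ref_rms / target_rms   # scale back to the prompt's level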
- assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def padded_mel_batch(ref_mels): - max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() - padded_ref_mels = [] - for mel in ref_mels: - padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) - padded_ref_mels.append(padded_ref_mel) - padded_ref_mels = torch.stack(padded_ref_mels) - padded_ref_mels = padded_ref_mels.permute(0, 2, 1) - return padded_ref_mels - - -def get_seedtts_testset_metainfo(metalst): - f = open(metalst) - lines = f.readlines() - f.close() - metainfo = [] - for line in lines: - assert len(line.strip().split("|")) == 4 - utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") - utt = Path(utt).stem - gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") - if not os.path.isabs(prompt_wav): - prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) - metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) - return metainfo - - -def main(): - args = get_parser() - - accelerator = Accelerator() - device = f"cuda:{accelerator.process_index}" - if args.manifest_file: - metainfo = get_seedtts_testset_metainfo(args.manifest_file) - if not args.use_cosyvoice_semantic_token: - prompts_all = get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - ) - else: - prompts_all = get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - else: - assert args.use_cosyvoice_semantic_token - dataset = datasets.load_dataset( - "yuekai/seed_tts_cosy2", - split=args.split_name, - trust_remote_code=True, - ) - prompts_all = get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - 
target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - - vocoder = BigVGANInference.from_pretrained( - "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False - ) - vocoder = vocoder.eval().to(device) - - model = get_model(args).eval().to(device) - checkpoint = torch.load(args.model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: - model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) - else: - _ = load_checkpoint( - args.model_path, - model=model, - ) - - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - start = time.time() - - with accelerator.split_between_processes(prompts_all) as prompts: - for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): - ( - utts, - ref_rms_list, - ref_mels, - ref_mel_lens, - total_mel_lens, - final_text_list, - ) = prompt - ref_mels = ref_mels.to(device) - ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) - total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) - - if args.use_cosyvoice_semantic_token: - # concat final_text_list - max_len = max([len(tokens) for tokens in final_text_list]) - # pad tokens to the same length - for i, tokens in enumerate(final_text_list): - final_text_list[i] = torch.tensor( - tokens + [-1] * (max_len - len(tokens)), dtype=torch.long - ) - final_text_list = torch.stack(final_text_list).to(device) - - # Inference - with torch.inference_mode(): - generated, _ = model.sample( - cond=ref_mels, - text=final_text_list, - duration=total_mel_lens, - lens=ref_mel_lens, - steps=args.nfe, - cfg_strength=2.0, - sway_sampling_coef=args.swaysampling, - no_ref_audio=False, - seed=args.seed, - ) - for i, gen in enumerate(generated): - gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) - gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) - - generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() - target_rms = 0.1 - target_sample_rate = 24_000 - if ref_rms_list[i] < target_rms: - generated_wave = generated_wave * ref_rms_list[i] / target_rms - torchaudio.save( - f"{args.output_dir}/{utts[i]}.wav", - generated_wave, - target_sample_rate, - ) - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - timediff = time.time() - start - print(f"Done batch inference in {timediff / 60 :.2f} minutes.") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/README.md b/egs/wenetspeech4tts/TTS/f5-tts/model/README.md deleted file mode 100644 index e4a7e2a7c..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Introduction -Files in this folder are copied from -https://github.com/SWivid/F5-TTS/tree/main/src/f5_tts/model diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py deleted file mode 100644 index 349c7220e..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -ein notation: -b - batch -n - sequence -nt - text sequence -nw - raw wave length -d - dimension -""" - -from __future__ import annotations - -from random import random -from typing import Callable - -import torch -import torch.nn.functional as F 
-from model.modules import MelSpec -from model.utils import ( - default, - exists, - lens_to_mask, - list_str_to_idx, - list_str_to_tensor, - mask_from_frac_lengths, -) -from torch import nn -from torch.nn.utils.rnn import pad_sequence -from torchdiffeq import odeint - - -class CFM(nn.Module): - def __init__( - self, - transformer: nn.Module, - sigma=0.0, - odeint_kwargs: dict = dict( - # atol = 1e-5, - # rtol = 1e-5, - method="euler" # 'midpoint' - ), - audio_drop_prob=0.3, - cond_drop_prob=0.2, - num_channels=None, - mel_spec_module: nn.Module | None = None, - mel_spec_kwargs: dict = dict(), - frac_lengths_mask: tuple[float, float] = (0.7, 1.0), - vocab_char_map: dict[str:int] | None = None, - ): - super().__init__() - - self.frac_lengths_mask = frac_lengths_mask - - # mel spec - self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) - num_channels = default(num_channels, self.mel_spec.n_mel_channels) - self.num_channels = num_channels - - # classifier-free guidance - self.audio_drop_prob = audio_drop_prob - self.cond_drop_prob = cond_drop_prob - - # transformer - self.transformer = transformer - dim = transformer.dim - self.dim = dim - - # conditional flow related - self.sigma = sigma - - # sampling related - self.odeint_kwargs = odeint_kwargs - - # vocab map for tokenization - self.vocab_char_map = vocab_char_map - - @property - def device(self): - return next(self.parameters()).device - - @torch.no_grad() - def sample( - self, - cond: float["b n d"] | float["b nw"], # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - duration: int | int["b"], # noqa: F821 - *, - lens: int["b"] | None = None, # noqa: F821 - steps=32, - cfg_strength=1.0, - sway_sampling_coef=None, - seed: int | None = None, - max_duration=4096, - vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 - no_ref_audio=False, - duplicate_test=False, - t_inter=0.1, - edit_mask=None, - ): - self.eval() - # raw wave - - if cond.ndim == 2: - cond = self.mel_spec(cond) - cond = cond.permute(0, 2, 1) - assert cond.shape[-1] == self.num_channels - - cond = cond.to(next(self.parameters()).dtype) - - batch, cond_seq_len, device = *cond.shape[:2], cond.device - if not exists(lens): - lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) - - # text - - if isinstance(text, list): - if exists(self.vocab_char_map): - text = list_str_to_idx(text, self.vocab_char_map).to(device) - else: - text = list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - if exists(text): - text_lens = (text != -1).sum(dim=-1) - lens = torch.maximum( - text_lens, lens - ) # make sure lengths are at least those of the text characters - - # duration - - cond_mask = lens_to_mask(lens) - if edit_mask is not None: - cond_mask = cond_mask & edit_mask - - if isinstance(duration, int): - duration = torch.full((batch,), duration, device=device, dtype=torch.long) - - duration = torch.maximum( - lens + 1, duration - ) # just add one token so something is generated - duration = duration.clamp(max=max_duration) - max_duration = duration.amax() - - # duplicate test corner for inner time step oberservation - if duplicate_test: - test_cond = F.pad( - cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 - ) - - cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) - cond_mask = F.pad( - cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False - ) - cond_mask = cond_mask.unsqueeze(-1) - step_cond = torch.where( - cond_mask, cond, torch.zeros_like(cond) - ) 
# allow direct control (cut cond audio) with lens passed in - - if batch > 1: - mask = lens_to_mask(duration) - else: # save memory and speed up, as single inference need no mask currently - mask = None - - # test for no ref audio - if no_ref_audio: - cond = torch.zeros_like(cond) - - # neural ode - - def fn(t, x): - # at each step, conditioning is fixed - # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - - # predict flow - pred = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - drop_audio_cond=False, - drop_text=False, - ) - if cfg_strength < 1e-5: - return pred - - null_pred = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - drop_audio_cond=True, - drop_text=True, - ) - return pred + (pred - null_pred) * cfg_strength - - # noise input - # to make sure batch inference result is same with different batch size, and for sure single inference - # still some difference maybe due to convolutional layers - y0 = [] - for dur in duration: - if exists(seed): - torch.manual_seed(seed) - y0.append( - torch.randn( - dur, self.num_channels, device=self.device, dtype=step_cond.dtype - ) - ) - y0 = pad_sequence(y0, padding_value=0, batch_first=True) - - t_start = 0 - - # duplicate test corner for inner time step oberservation - if duplicate_test: - t_start = t_inter - y0 = (1 - t_start) * y0 + t_start * test_cond - steps = int(steps * (1 - t_start)) - - t = torch.linspace( - t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype - ) - if sway_sampling_coef is not None: - t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) - - trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - - sampled = trajectory[-1] - out = sampled - out = torch.where(cond_mask, cond, out) - - if exists(vocoder): - out = out.permute(0, 2, 1) - out = vocoder(out) - - return out, trajectory - - def forward( - self, - inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - *, - lens: int["b"] | None = None, # noqa: F821 - noise_scheduler: str | None = None, - ): - # handle raw wave - if inp.ndim == 2: - inp = self.mel_spec(inp) - inp = inp.permute(0, 2, 1) - assert inp.shape[-1] == self.num_channels - - batch, seq_len, dtype, device, _σ1 = ( - *inp.shape[:2], - inp.dtype, - self.device, - self.sigma, - ) - - # handle text as string - if isinstance(text, list): - if exists(self.vocab_char_map): - text = list_str_to_idx(text, self.vocab_char_map).to(device) - else: - text = list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - # lens and mask - if not exists(lens): - lens = torch.full((batch,), seq_len, device=device) - - mask = lens_to_mask( - lens, length=seq_len - ) # useless here, as collate_fn will pad to max length in batch - - # get a random span to mask out for training conditionally - frac_lengths = ( - torch.zeros((batch,), device=self.device) - .float() - .uniform_(*self.frac_lengths_mask) - ) - rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - - if exists(mask): - rand_span_mask &= mask - - # mel is x1 - x1 = inp - - # x0 is gaussian noise - x0 = torch.randn_like(x1) - - # time step - time = torch.rand((batch,), dtype=dtype, device=self.device) - # TODO. 
noise_scheduler - - # sample xt (φ_t(x) in the paper) - t = time.unsqueeze(-1).unsqueeze(-1) - φ = (1 - t) * x0 + t * x1 - flow = x1 - x0 - - # only predict what is within the random mask span for infilling - cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - - # transformer and cfg training with a drop rate - drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper - if random() < self.cond_drop_prob: # p_uncond in voicebox paper - drop_audio_cond = True - drop_text = True - else: - drop_text = False - - # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here - # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences - pred = self.transformer( - x=φ, - cond=cond, - text=text, - time=time, - drop_audio_cond=drop_audio_cond, - drop_text=drop_text, - ) - - # flow matching loss - loss = F.mse_loss(pred, flow, reduction="none") - loss = loss[rand_span_mask] - - return loss.mean(), cond, pred diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py deleted file mode 100644 index 966fabfdd..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -ein notation: -b - batch -n - sequence -nt - text sequence -nw - raw wave length -d - dimension -""" - -from __future__ import annotations - -import torch -import torch.nn.functional as F -from model.modules import ( - AdaLayerNormZero_Final, - ConvNeXtV2Block, - ConvPositionEmbedding, - DiTBlock, - TimestepEmbedding, - get_pos_embed_indices, - precompute_freqs_cis, -) -from torch import nn -from x_transformers.x_transformers import RotaryEmbedding - -# Text embedding - - -class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): - super().__init__() - self.text_embed = nn.Embedding( - text_num_embeds + 1, text_dim - ) # use 0 as filler token - - if conv_layers > 0: - self.extra_modeling = True - self.precompute_max_pos = 4096 # ~44s of 24khz audio - self.register_buffer( - "freqs_cis", - precompute_freqs_cis(text_dim, self.precompute_max_pos), - persistent=False, - ) - self.text_blocks = nn.Sequential( - *[ - ConvNeXtV2Block(text_dim, text_dim * conv_mult) - for _ in range(conv_layers) - ] - ) - else: - self.extra_modeling = False - - def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - text = ( - text + 1 - ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() - text = text[ - :, :seq_len - ] # curtail if character tokens are more than the mel spec tokens - batch, text_len = text.shape[0], text.shape[1] - text = F.pad(text, (0, seq_len - text_len), value=0) - - if drop_text: # cfg for text - text = torch.zeros_like(text) - - text = self.text_embed(text) # b n -> b n d - - # possible extra modeling - if self.extra_modeling: - # sinus pos emb - batch_start = torch.zeros((batch,), dtype=torch.long) - pos_idx = get_pos_embed_indices( - batch_start, seq_len, max_pos=self.precompute_max_pos - ) - text_pos_embed = self.freqs_cis[pos_idx] - text = text + text_pos_embed - - # convnextv2 blocks - text = self.text_blocks(text) - - return text - - -# noised input audio and context mixing embedding - - -class InputEmbedding(nn.Module): - def __init__(self, mel_dim, text_dim, out_dim): - super().__init__() - self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) - self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - - def forward( - self, - x: float["b n d"], # noqa: F722 - cond: float["b n d"], # noqa: F722 - text_embed: float["b n d"], # noqa: F722 - drop_audio_cond=False, - ): - if drop_audio_cond: # cfg for cond audio - cond = torch.zeros_like(cond) - - x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) - x = self.conv_pos_embed(x) + x - return x - - -# Transformer backbone using DiT blocks - - -class DiT(nn.Module): - def __init__( - self, - *, - dim, - depth=8, - heads=8, - dim_head=64, - dropout=0.1, - ff_mult=4, - mel_dim=100, - text_num_embeds=256, - text_dim=None, - conv_layers=0, - long_skip_connection=False, - checkpoint_activations=False, - ): - super().__init__() - - self.time_embed = TimestepEmbedding(dim) - if text_dim is None: - text_dim = mel_dim - self.text_embed = TextEmbedding( - text_num_embeds, text_dim, conv_layers=conv_layers - ) - self.input_embed = InputEmbedding(mel_dim, text_dim, dim) - - self.rotary_embed = RotaryEmbedding(dim_head) - - self.dim = dim - self.depth = depth - - self.transformer_blocks = nn.ModuleList( - [ - DiTBlock( - dim=dim, - heads=heads, - dim_head=dim_head, - ff_mult=ff_mult, - dropout=dropout, - ) - for _ in range(depth) - ] - ) - self.long_skip_connection = ( - nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - ) - - self.norm_out = AdaLayerNormZero_Final(dim) # final modulation - self.proj_out = nn.Linear(dim, mel_dim) - - self.checkpoint_activations = checkpoint_activations - - def ckpt_wrapper(self, module): - # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py - def ckpt_forward(*inputs): - outputs = module(*inputs) - return outputs - - return ckpt_forward - - def forward( - self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool["b n"] | None = None, # noqa: F722 - ): - batch, seq_len = x.shape[0], x.shape[1] - if time.ndim == 0: - time = time.repeat(batch) - - # t: conditioning time, c: context (text + masked cond audio), x: noised input audio - t = self.time_embed(time) - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) - x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) - - rope = self.rotary_embed.forward_from_seq_len(seq_len) - - if self.long_skip_connection is not None: - residual = x - - for block in 
self.transformer_blocks: - if self.checkpoint_activations: - x = torch.utils.checkpoint.checkpoint( - self.ckpt_wrapper(block), x, t, mask, rope - ) - else: - x = block(x, t, mask=mask, rope=rope) - - if self.long_skip_connection is not None: - x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) - - x = self.norm_out(x, t) - output = self.proj_out(x) - - return output diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py deleted file mode 100644 index 05299d419..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py +++ /dev/null @@ -1,728 +0,0 @@ -""" -ein notation: -b - batch -n - sequence -nt - text sequence -nw - raw wave length -d - dimension -""" - -from __future__ import annotations - -import math -from typing import Optional - -import torch -import torch.nn.functional as F -import torchaudio -from librosa.filters import mel as librosa_mel_fn -from torch import nn -from x_transformers.x_transformers import apply_rotary_pos_emb - -# raw wav to mel spec - - -mel_basis_cache = {} -hann_window_cache = {} - - -def get_bigvgan_mel_spectrogram( - waveform, - n_fft=1024, - n_mel_channels=100, - target_sample_rate=24000, - hop_length=256, - win_length=1024, - fmin=0, - fmax=None, - center=False, -): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main - device = waveform.device - key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" - - if key not in mel_basis_cache: - mel = librosa_mel_fn( - sr=target_sample_rate, - n_fft=n_fft, - n_mels=n_mel_channels, - fmin=fmin, - fmax=fmax, - ) - mel_basis_cache[key] = ( - torch.from_numpy(mel).float().to(device) - ) # TODO: why they need .float()? - hann_window_cache[key] = torch.hann_window(win_length).to(device) - - mel_basis = mel_basis_cache[key] - hann_window = hann_window_cache[key] - - padding = (n_fft - hop_length) // 2 - waveform = torch.nn.functional.pad( - waveform.unsqueeze(1), (padding, padding), mode="reflect" - ).squeeze(1) - - spec = torch.stft( - waveform, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) - - mel_spec = torch.matmul(mel_basis, spec) - mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) - - return mel_spec - - -def get_vocos_mel_spectrogram( - waveform, - n_fft=1024, - n_mel_channels=100, - target_sample_rate=24000, - hop_length=256, - win_length=1024, -): - mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate=target_sample_rate, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - n_mels=n_mel_channels, - power=1, - center=True, - normalized=False, - norm=None, - ).to(waveform.device) - if len(waveform.shape) == 3: - waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' - - assert len(waveform.shape) == 2 - - mel = mel_stft(waveform) - mel = mel.clamp(min=1e-5).log() - return mel - - -class MelSpec(nn.Module): - def __init__( - self, - n_fft=1024, - hop_length=256, - win_length=1024, - n_mel_channels=100, - target_sample_rate=24_000, - mel_spec_type="vocos", - ): - super().__init__() - assert mel_spec_type in ["vocos", "bigvgan"], print( - "We only support two extract mel backend: vocos or bigvgan" - ) - - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.n_mel_channels = n_mel_channels - self.target_sample_rate = target_sample_rate 
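# Usage sketch (a non-authoritative illustration, assuming the defaults above:
# 24 kHz audio, hop 256, 100 mel bins, and the "bigvgan" backend):
#
#     mel_spec = MelSpec(mel_spec_type="bigvgan")
#     wav = torch.randn(1, 24000)   # (batch, num_samples), one second of audio
#     mel = mel_spec(wav)           # -> (1, 100, 93), i.e. (batch, n_mels, frames)
#
# Callers such as padded_mel_batch above permute this to (batch, frames, n_mels).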
- - if mel_spec_type == "vocos": - self.extractor = get_vocos_mel_spectrogram - elif mel_spec_type == "bigvgan": - self.extractor = get_bigvgan_mel_spectrogram - - self.register_buffer("dummy", torch.tensor(0), persistent=False) - - def forward(self, wav): - if self.dummy.device != wav.device: - self.to(wav.device) - - mel = self.extractor( - waveform=wav, - n_fft=self.n_fft, - n_mel_channels=self.n_mel_channels, - target_sample_rate=self.target_sample_rate, - hop_length=self.hop_length, - win_length=self.win_length, - ) - - return mel - - -# sinusoidal position embedding - - -class SinusPositionEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -# convolutional position embedding - - -class ConvPositionEmbedding(nn.Module): - def __init__(self, dim, kernel_size=31, groups=16): - super().__init__() - assert kernel_size % 2 != 0 - self.conv1d = nn.Sequential( - nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), - nn.Mish(), - nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), - nn.Mish(), - ) - - def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 - if mask is not None: - mask = mask[..., None] - x = x.masked_fill(~mask, 0.0) - - x = x.permute(0, 2, 1) - x = self.conv1d(x) - out = x.permute(0, 2, 1) - - if mask is not None: - out = out.masked_fill(~mask, 0.0) - - return out - - -# rotary positional embedding related - - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 -): - # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning - # has some connection to NTK literature - # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py - theta *= theta_rescale_factor ** (dim / (dim - 2)) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cos = torch.cos(freqs) # real part - freqs_sin = torch.sin(freqs) # imaginary part - return torch.cat([freqs_cos, freqs_sin], dim=-1) - - -def get_pos_embed_indices(start, length, max_pos, scale=1.0): - # length = length if isinstance(length, int) else length.max() - scale = scale * torch.ones_like( - start, dtype=torch.float32 - ) # in case scale is a scalar - pos = ( - start.unsqueeze(1) - + ( - torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) - * scale.unsqueeze(1) - ).long() - ) - # avoid extra long error. - pos = torch.where(pos < max_pos, pos, max_pos - 1) - return pos - - -# Global Response Normalization layer (Instance Normalization ?) 
- - -class GRN(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=1, keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py -# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 - - -class ConvNeXtV2Block(nn.Module): - def __init__( - self, - dim: int, - intermediate_dim: int, - dilation: int = 1, - ): - super().__init__() - padding = (dilation * (7 - 1)) // 2 - self.dwconv = nn.Conv1d( - dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation - ) # depthwise conv - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, intermediate_dim - ) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.grn = GRN(intermediate_dim) - self.pwconv2 = nn.Linear(intermediate_dim, dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = x.transpose(1, 2) # b n d -> b d n - x = self.dwconv(x) - x = x.transpose(1, 2) # b d n -> b n d - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - return residual + x - - -# AdaLayerNormZero -# return with modulated x for attn input, and params for later mlp modulation - - -class AdaLayerNormZero(nn.Module): - def __init__(self, dim): - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(dim, dim * 6) - - self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - - def forward(self, x, emb=None): - emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( - emb, 6, dim=1 - ) - - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -# AdaLayerNormZero for final layer -# return only with modulated x for attn input, cuz no more mlp modulation - - -class AdaLayerNormZero_Final(nn.Module): - def __init__(self, dim): - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(dim, dim * 2) - - self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - - def forward(self, x, emb): - emb = self.linear(self.silu(emb)) - scale, shift = torch.chunk(emb, 2, dim=1) - - x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] - return x - - -# FeedForward - - -class FeedForward(nn.Module): - def __init__( - self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - activation = nn.GELU(approximate=approximate) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) - self.ff = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.ff(x) - - -# Attention with possible joint part -# modified from diffusers/src/diffusers/models/attention_processor.py - - -class Attention(nn.Module): - def __init__( - self, - processor: JointAttnProcessor | AttnProcessor, - dim: int, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only=None, - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - 
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - self.processor = processor - - self.dim = dim - self.heads = heads - self.inner_dim = dim_head * heads - self.dropout = dropout - - self.context_dim = context_dim - self.context_pre_only = context_pre_only - - self.to_q = nn.Linear(dim, self.inner_dim) - self.to_k = nn.Linear(dim, self.inner_dim) - self.to_v = nn.Linear(dim, self.inner_dim) - - if self.context_dim is not None: - self.to_k_c = nn.Linear(context_dim, self.inner_dim) - self.to_v_c = nn.Linear(context_dim, self.inner_dim) - if self.context_pre_only is not None: - self.to_q_c = nn.Linear(context_dim, self.inner_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, dim)) - self.to_out.append(nn.Dropout(dropout)) - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_out_c = nn.Linear(self.inner_dim, dim) - - def forward( - self, - x: float["b n d"], # noised input x # noqa: F722 - c: float["b n d"] = None, # context c # noqa: F722 - mask: bool["b n"] | None = None, # noqa: F722 - rope=None, # rotary position embedding for x - c_rope=None, # rotary position embedding for c - ) -> torch.Tensor: - if c is not None: - return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) - else: - return self.processor(self, x, mask=mask, rope=rope) - - -# Attention processor - - -class AttnProcessor: - def __init__(self): - pass - - def __call__( - self, - attn: Attention, - x: float["b n d"], # noised input x # noqa: F722 - mask: bool["b n"] | None = None, # noqa: F722 - rope=None, # rotary position embedding - ) -> torch.FloatTensor: - batch_size = x.shape[0] - - # `sample` projections. - query = attn.to_q(x) - key = attn.to_k(x) - value = attn.to_v(x) - - # apply rotary position embedding - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = ( - (xpos_scale, xpos_scale**-1.0) - if xpos_scale is not None - else (1.0, 1.0) - ) - - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - - # attention - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # mask. e.g. 
inference got a batch with different target durations, mask out the padding - if mask is not None: - attn_mask = mask - attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand( - batch_size, attn.heads, query.shape[-2], key.shape[-2] - ) - else: - attn_mask = None - - x = F.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False - ) - x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - x = x.to(query.dtype) - - # linear proj - x = attn.to_out[0](x) - # dropout - x = attn.to_out[1](x) - - if mask is not None: - mask = mask.unsqueeze(-1) - x = x.masked_fill(~mask, 0.0) - - return x - - -# Joint Attention processor for MM-DiT -# modified from diffusers/src/diffusers/models/attention_processor.py - - -class JointAttnProcessor: - def __init__(self): - pass - - def __call__( - self, - attn: Attention, - x: float["b n d"], # noised input x # noqa: F722 - c: float["b nt d"] = None, # context c, here text # noqa: F722 - mask: bool["b n"] | None = None, # noqa: F722 - rope=None, # rotary position embedding for x - c_rope=None, # rotary position embedding for c - ) -> torch.FloatTensor: - residual = x - - batch_size = c.shape[0] - - # `sample` projections. - query = attn.to_q(x) - key = attn.to_k(x) - value = attn.to_v(x) - - # `context` projections. - c_query = attn.to_q_c(c) - c_key = attn.to_k_c(c) - c_value = attn.to_v_c(c) - - # apply rope for context and noised input independently - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = ( - (xpos_scale, xpos_scale**-1.0) - if xpos_scale is not None - else (1.0, 1.0) - ) - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - if c_rope is not None: - freqs, xpos_scale = c_rope - q_xpos_scale, k_xpos_scale = ( - (xpos_scale, xpos_scale**-1.0) - if xpos_scale is not None - else (1.0, 1.0) - ) - c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) - c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) - - # attention - query = torch.cat([query, c_query], dim=1) - key = torch.cat([key, c_key], dim=1) - value = torch.cat([value, c_value], dim=1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # mask. e.g. inference got a batch with different target durations, mask out the padding - if mask is not None: - attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) - attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand( - batch_size, attn.heads, query.shape[-2], key.shape[-2] - ) - else: - attn_mask = None - - x = F.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False - ) - x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - x = x.to(query.dtype) - - # Split the attention outputs. - x, c = ( - x[:, : residual.shape[1]], - x[:, residual.shape[1] :], - ) - - # linear proj - x = attn.to_out[0](x) - # dropout - x = attn.to_out[1](x) - if not attn.context_pre_only: - c = attn.to_out_c(c) - - if mask is not None: - mask = mask.unsqueeze(-1) - x = x.masked_fill(~mask, 0.0) - # c = c.masked_fill(~mask, 0.) 
# no mask for c (text) - - return x, c - - -# DiT Block - - -class DiTBlock(nn.Module): - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): - super().__init__() - - self.attn_norm = AdaLayerNormZero(dim) - self.attn = Attention( - processor=AttnProcessor(), - dim=dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - ) - - self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward( - dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" - ) - - def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding - # pre-norm & modulation for attention input - norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) - - # attention - attn_output = self.attn(x=norm, mask=mask, rope=rope) - - # process attention output for input x - x = x + gate_msa.unsqueeze(1) * attn_output - - norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm) - x = x + gate_mlp.unsqueeze(1) * ff_output - - return x - - -# MMDiT Block https://arxiv.org/abs/2403.03206 - - -class MMDiTBlock(nn.Module): - r""" - modified from diffusers/src/diffusers/models/attention.py - - notes. - _c: context related. text, cond, etc. (left part in sd3 fig2.b) - _x: noised input related. (right part) - context_pre_only: last layer only do prenorm + modulation cuz no more ffn - """ - - def __init__( - self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False - ): - super().__init__() - - self.context_pre_only = context_pre_only - - self.attn_norm_c = ( - AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) - ) - self.attn_norm_x = AdaLayerNormZero(dim) - self.attn = Attention( - processor=JointAttnProcessor(), - dim=dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - context_dim=dim, - context_pre_only=context_pre_only, - ) - - if not context_pre_only: - self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward( - dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" - ) - else: - self.ff_norm_c = None - self.ff_c = None - self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_x = FeedForward( - dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" - ) - - def forward( - self, x, c, t, mask=None, rope=None, c_rope=None - ): # x: noised input, c: context, t: time embedding - # pre-norm & modulation for attention input - if self.context_pre_only: - norm_c = self.attn_norm_c(c, t) - else: - norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( - c, emb=t - ) - norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( - x, emb=t - ) - - # attention - x_attn_output, c_attn_output = self.attn( - x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope - ) - - # process attention output for context c - if self.context_pre_only: - c = None - else: # if not last layer - c = c + c_gate_msa.unsqueeze(1) * c_attn_output - - norm_c = ( - self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - ) - c_ff_output = self.ff_c(norm_c) - c = c + c_gate_mlp.unsqueeze(1) * c_ff_output - - # process attention output for input x - x = x + x_gate_msa.unsqueeze(1) * x_attn_output - - norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] - x_ff_output = self.ff_x(norm_x) - x = x + x_gate_mlp.unsqueeze(1) * x_ff_output - - return c, x - - -# time step conditioning embedding - - -class TimestepEmbedding(nn.Module): - def 
__init__(self, dim, freq_embed_dim=256): - super().__init__() - self.time_embed = SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.Sequential( - nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) - ) - - def forward(self, timestep: float["b"]): # noqa: F821 - time_hidden = self.time_embed(timestep) - time_hidden = time_hidden.to(timestep.dtype) - time = self.time_mlp(time_hidden) # b d - return time diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py deleted file mode 100644 index fae5fadb6..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py +++ /dev/null @@ -1,206 +0,0 @@ -from __future__ import annotations - -import os -import random -from collections import defaultdict -from importlib.resources import files - -import jieba -import torch -from pypinyin import Style, lazy_pinyin -from torch.nn.utils.rnn import pad_sequence - -# seed everything - - -def seed_everything(seed=0): - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -# helpers - - -def exists(v): - return v is not None - - -def default(v, d): - return v if exists(v) else d - - -# tensor helpers - - -def lens_to_mask( - t: int["b"], length: int | None = None # noqa: F722 F821 -) -> bool["b n"]: # noqa: F722 F821 - if not exists(length): - length = t.amax() - - seq = torch.arange(length, device=t.device) - return seq[None, :] < t[:, None] - - -def mask_from_start_end_indices( - seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 -): - max_seq_len = seq_len.max().item() - seq = torch.arange(max_seq_len, device=start.device).long() - start_mask = seq[None, :] >= start[:, None] - end_mask = seq[None, :] < end[:, None] - return start_mask & end_mask - - -def mask_from_frac_lengths( - seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 -): - lengths = (frac_lengths * seq_len).long() - max_start = seq_len - lengths - - rand = torch.rand_like(frac_lengths) - start = (max_start * rand).long().clamp(min=0) - end = start + lengths - - return mask_from_start_end_indices(seq_len, start, end) - - -def maybe_masked_mean( - t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 -) -> float["b d"]: # noqa: F722 F821 - if not exists(mask): - return t.mean(dim=1) - - t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) - num = t.sum(dim=1) - den = mask.float().sum(dim=1) - - return num / den.clamp(min=1.0) - - -# simple utf-8 tokenizer, since paper went character based -def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 - list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style - text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) - return text - - -# char tokenizer, based on custom dataset's extracted .txt file -def list_str_to_idx( - text: list[str] | list[list[str]], - vocab_char_map: dict[str, int], # {char: idx} - padding_value=-1, -) -> int["b nt"]: # noqa: F722 - list_idx_tensors = [ - torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text - ] # pinyin or char style - text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) - return text - - -# Get tokenizer - - -def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - """ - tokenizer - "pinyin" do g2p for only chinese characters, 
need .txt vocab_file - - "char" for char-wise tokenizer, need .txt vocab_file - - "byte" for utf-8 tokenizer - - "custom" if you're directly passing in a path to the vocab.txt you want to use - vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - - if use "char", derived from unfiltered character & symbol counts of custom dataset - - if use "byte", set to 256 (unicode byte range) - """ - if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join( - files("f5_tts").joinpath("../../data"), - f"{dataset_name}_{tokenizer}/vocab.txt", - ) - with open(tokenizer_path, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - assert ( - vocab_char_map[" "] == 0 - ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" - - elif tokenizer == "byte": - vocab_char_map = None - vocab_size = 256 - - elif tokenizer == "custom": - with open(dataset_name, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - - return vocab_char_map, vocab_size - - -# convert char to pinyin - -jieba.initialize() -print("Word segmentation module jieba initialized.\n") - - -def convert_char_to_pinyin(text_list, polyphone=True): - final_text_list = [] - custom_trans = str.maketrans( - {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} - ) # add custom trans here, to address oov - - def is_chinese(c): - return "\u3100" <= c <= "\u9fff" # common chinese characters - - for text in text_list: - char_list = [] - text = text.translate(custom_trans) - for seg in jieba.cut(text): - seg_byte_len = len(bytes(seg, "UTF-8")) - if seg_byte_len == len(seg): # if pure alphabets and symbols - if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": - char_list.append(" ") - char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len( - seg - ): # if pure east asian characters - seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) - for i, c in enumerate(seg): - if is_chinese(c): - char_list.append(" ") - char_list.append(seg_[i]) - else: # if mixed characters, alphabets and symbols - for c in seg: - if ord(c) < 256: - char_list.extend(c) - elif is_chinese(c): - char_list.append(" ") - char_list.extend( - lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) - ) - else: - char_list.append(c) - final_text_list.append(char_list) - - return final_text_list - - -# filter func for dirty data with many repetitions - - -def repetition_found(text, length=2, tolerance=10): - pattern_count = defaultdict(int) - for i in range(len(text) - length + 1): - pattern = text[i : i + length] - pattern_count[pattern] += 1 - for pattern, count in pattern_count.items(): - if count > tolerance: - return True - return False diff --git a/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt b/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt deleted file mode 100644 index 63f1e237c..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt +++ /dev/null @@ -1,36 +0,0 @@ -# F5-TTS -accelerate>=0.33.0 -bitsandbytes>0.37.0 -cached_path -click -datasets -ema_pytorch>=0.5.2 -gradio>=3.45.2 -hydra-core>=1.3.0 -jieba -librosa -matplotlib -numpy<=1.26.4 -pydub -pypinyin -safetensors -soundfile -tomli -torch>=2.0.0 -torchaudio>=2.0.0 -torchdiffeq -tqdm>=4.65.0 -transformers -x_transformers>=1.31.14 - -# icefall -kaldialign -lhotse -tensorboard -bigvganinference -sentencepiece 
-sherpa-onnx -k2 - -# semantic experiment -s3tokenizer diff --git a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py deleted file mode 100644 index 7d42a00a5..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Callable, Dict, List, Sequence, Union - -import torch -from lhotse import validate -from lhotse.cut import CutSet -from lhotse.dataset.collation import collate_audio -from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures -from lhotse.utils import ifnone - - -class SpeechSynthesisDataset(torch.utils.data.Dataset): - """ - The PyTorch Dataset for the speech synthesis task. - Each item in this dataset is a dict of: - - .. code-block:: - - { - 'audio': (B x NumSamples) float tensor - 'features': (B x NumFrames x NumFeatures) float tensor - 'audio_lens': (B, ) int tensor - 'features_lens': (B, ) int tensor - 'text': List[str] of len B # when return_text=True - 'tokens': List[List[str]] # when return_tokens=True - 'speakers': List[str] of len B # when return_spk_ids=True - 'cut': List of Cuts # when return_cuts=True - } - """ - - def __init__( - self, - cut_transforms: List[Callable[[CutSet], CutSet]] = None, - feature_input_strategy: BatchIO = PrecomputedFeatures(), - feature_transforms: Union[Sequence[Callable], Callable] = None, - return_text: bool = True, - return_tokens: bool = False, - return_spk_ids: bool = False, - return_cuts: bool = False, - ) -> None: - super().__init__() - - self.cut_transforms = ifnone(cut_transforms, []) - self.feature_input_strategy = feature_input_strategy - - self.return_text = return_text - self.return_tokens = return_tokens - self.return_spk_ids = return_spk_ids - self.return_cuts = return_cuts - - if feature_transforms is None: - feature_transforms = [] - elif not isinstance(feature_transforms, Sequence): - feature_transforms = [feature_transforms] - - assert all( - isinstance(transform, Callable) for transform in feature_transforms - ), "Feature transforms must be Callable" - self.feature_transforms = feature_transforms - - def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: - validate_for_tts(cuts) - - for transform in self.cut_transforms: - cuts = transform(cuts) - - # audio, audio_lens = collate_audio(cuts) - features, features_lens = self.feature_input_strategy(cuts) - - for transform in self.feature_transforms: - features = transform(features) - - batch = { - # "audio": audio, - "features": features, - # "audio_lens": audio_lens, - "features_lens": features_lens, - } - - if self.return_text: - # use normalized text - # text = [cut.supervisions[0].normalized_text for cut in cuts] - text = [cut.supervisions[0].text for cut in cuts] - batch["text"] = text - - if self.return_tokens and "speech_tokens" in cuts[0].supervisions[0].custom: - # tokens = [cut.tokens for cut in cuts] - # tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] - tokens = [cut.supervisions[0].custom["speech_tokens"] for cut in cuts] - # change str into list - tokens = [list(map(int, token.split())) for token in tokens] - batch["tokens"] = tokens - - if self.return_spk_ids: - batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] - - if self.return_cuts: - batch["cut"] = [cut for cut in cuts] - - return batch - - -def validate_for_tts(cuts: CutSet) -> None: - validate(cuts) - for cut in cuts: - assert ( - len(cut.supervisions) == 1 - ), "Only the Cuts with single supervision are supported." 
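As a usage note for the dataset above: when `return_tokens=True`, the semantic tokens are read from `cut.supervisions[0].custom["speech_tokens"]` as a space-separated string and converted to a list of ints. A minimal sketch of that conversion and of the resulting batch fields (the token string below is made up):

# Hypothetical supervision custom field, in the format the dataset above expects.
speech_tokens_str = "101 7 4593 220 88"
tokens = list(map(int, speech_tokens_str.split()))
print(tokens)  # [101, 7, 4593, 220, 88]

# A batch produced by SpeechSynthesisDataset then contains, per its docstring:
#   batch["features"]      -> (B, NumFrames, NumFeatures) float tensor
#   batch["features_lens"] -> (B,) int tensor
#   batch["text"]          -> list of B strings
#   batch["tokens"]        -> list of B lists of ints (when return_tokens=True)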
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py deleted file mode 100755 index 5333b3f27..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ /dev/null @@ -1,1233 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo) -# Copyright 2023 (authors: Feiteng Li) -# Copyright 2024 (authors: Yuekai Zhang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece - -world_size=8 -exp_dir=exp/f5-tts-small -python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ - --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ - --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ - --exp-dir ${exp_dir} --world-size ${world_size} - -# command for training with cosyvoice semantic token -exp_dir=exp/f5-tts-cosyvoice -python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ - --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \ - --num-epochs 10 --start-epoch 1 --start-batch 0 \ - --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ - --exp-dir ${exp_dir} --world-size ${world_size} \ - --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True -""" - -import argparse -import copy -import logging -import os -import random -import warnings -from contextlib import nullcontext -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse import CutSet -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model.cfm import CFM -from model.dit import DiT -from model.utils import convert_char_to_pinyin -from torch import Tensor -from torch.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import LinearLR, SequentialLR -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import TtsDataModule -from utils import MetricsTracker - -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, 
setup_logger, str2bool # MetricsTracker - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--decoder-dim", - type=int, - default=1024, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=16, - help="Number of attention heads in the Decoder layers.", - ) - - parser.add_argument( - "--num-decoder-layers", - type=int, - default=22, - help="Number of Decoder layers.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="exp/f5", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="f5-tts/vocab.txt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--pretrained-model-path", - type=str, - default=None, - help="Path to file", - ) - - parser.add_argument( - "--optimizer-name", - type=str, - default="AdamW", - help="The optimizer.", - ) - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - parser.add_argument( - "--warmup-steps", - type=int, - default=200, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--decay-steps", - type=int, - default=1000000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=10000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train %% save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - parser.add_argument( - "--valid-interval", - type=int, - default=10000, - help="""Run validation if batch_idx %% valid_interval is 0.""", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=0, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--accumulate-grad-steps", - type=int, - default=1, - help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. - """, - ) - - parser.add_argument( - "--dtype", - type=str, - default="bfloat16", - help="Training dtype: float32 bfloat16 float16.", - ) - - parser.add_argument( - "--filter-min-duration", - type=float, - default=0.0, - help="Keep only utterances with duration > this.", - ) - - parser.add_argument( - "--filter-max-duration", - type=float, - default=20.0, - help="Keep only utterances with duration < this.", - ) - - parser.add_argument( - "--oom-check", - type=str2bool, - default=False, - help="perform OOM check on dataloader batches before starting training.", - ) - - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 200, - "valid_interval": 10000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_tokenizer(vocab_file_path: str): - """ - tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - - "char" for char-wise tokenizer, need .txt vocab_file - - "byte" for utf-8 tokenizer - - "custom" if you're directly passing in a path to the vocab.txt you want to use - vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - - if use "char", derived from unfiltered character & symbol counts of custom dataset - - if use "byte", set to 256 (unicode byte range) - """ - with open(vocab_file_path, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - - return vocab_char_map, vocab_size - - -def get_model(params): - if params.use_cosyvoice_semantic_token: - # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36 - vocab_char_map, vocab_size = None, 6561 - else: - vocab_char_map, vocab_size = get_tokenizer(params.tokens) - # bigvgan 100 dim features - n_mel_channels = 100 - n_fft = 1024 - sampling_rate = 24_000 - hop_length = 256 - win_length = 1024 - - model_cfg = { - "dim": params.decoder_dim, - "depth": params.num_decoder_layers, - "heads": params.nhead, - "ff_mult": 2, - "text_dim": 512, - "conv_layers": 4, - "checkpoint_activations": False, - } - model = CFM( - transformer=DiT( - **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels - ), - mel_spec_kwargs=dict( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=sampling_rate, - mel_spec_type="bigvgan", - ), - odeint_kwargs=dict( - method="euler", - ), - vocab_char_map=vocab_char_map, - ) - return model - - -def load_F5_TTS_pretrained_checkpoint( - model, ckpt_path, device: str = "cpu", dtype=torch.float32 -): - checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) - if "ema_model_state_dict" in checkpoint: - checkpoint["model_state_dict"] = { - k.replace("ema_model.", ""): v - for k, v in checkpoint["ema_model_state_dict"].items() - if k not in ["initted", "step"] - } - - # patch for backward compatibility, 305e3ea - for key in [ - "mel_spec.mel_stft.mel_scale.fb", - "mel_spec.mel_stft.spectrogram.window", - ]: - if key in checkpoint["model_state_dict"]: - del checkpoint["model_state_dict"][key] - model.load_state_dict(checkpoint["model_state_dict"]) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. 
- - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - if isinstance(model, DDP): - raise ValueError("load_checkpoint before DDP") - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def interpolate_tokens(cosy_tokens, pad_token=-1): - """Interpolate cosyvoice tokens to match bigvgan frames length""" - # cosyvoice, 25 tokens/sec - # bigvgan sample_rate/hop_length 24000/256 frames/sec - # For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length - # We choose 4,4,4,3 to match 15 frames - three, two = [pad_token] * 3, [pad_token] * 2 - return [ - x - for i, e in enumerate(cosy_tokens) - for x in ([e] + three if i % 4 < 3 else [e] + two) - ] - - -def prepare_input( - batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool -): - """Parse batch data""" - mel_spec = batch["features"] - mel_lengths = batch["features_lens"] - - if use_cosyvoice_semantic_token: - semantic_tokens = [] - for i in range(len(batch["tokens"])): - tokens = batch["tokens"][i] - tokens = interpolate_tokens(tokens) - semantic_tokens.append(tokens) - # pad to the same length, B,T, with pad value -1 - max_len = max([len(tokens) for tokens in semantic_tokens]) - text_inputs = torch.full( - (len(semantic_tokens), max_len), -1, dtype=torch.long - ).to(device) - for i, tokens in enumerate(semantic_tokens): - text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) - else: - text_inputs = batch["text"] - text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) - - return text_inputs, mel_spec.to(device), mel_lengths.to(device) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - (text_inputs, mel_spec, mel_lengths) = prepare_input( - batch, - device=device, - use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token, - ) - # at entry, TextTokens is (N, P) - - with torch.set_grad_enabled(is_training): - loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["samples"] = mel_lengths.size(0) - - info["loss"] = loss.detach().cpu().item() * info["samples"] - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - if world_size > 1: - tot_loss.reduce(loss.device) - loss_value = tot_loss["loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - rng: random.Random, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - rng: - Random for selecting. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - tot_loss = MetricsTracker() - iter_dl = iter(train_dl) - - dtype, enabled = torch.float32, False - if params.dtype in ["bfloat16", "bf16"]: - dtype, enabled = torch.bfloat16, True - elif params.dtype in ["float16", "fp16"]: - dtype, enabled = torch.float16, True - - batch_idx = 0 - while True: - try: - batch = next(iter_dl) - except StopIteration: - logging.info("Reaches end of dataloader.") - break - - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = len(batch["text"]) - - try: - with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): - loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( - 1 / params.reset_interval - ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - scaler.scale(loss).backward() - if params.batch_idx_train >= params.accumulate_grad_steps: - if params.batch_idx_train % params.accumulate_grad_steps == 0: - - # Unscales the gradients of optimizer's assigned params in-place - scaler.unscale_(optimizer) - # Since the gradients of optimizer's assigned params are unscaled, clips as usual: - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - # loss.backward() - # optimizer.step() - - for k in range(params.accumulate_grad_steps): - scheduler.step() - - set_batch_count(model, params.batch_idx_train) - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.average_period > 0: - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - # Perform Operation in rank 0 - if rank == 0: - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - # Perform Operation in rank 0 - if rank == 0: - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = ( - scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 - ) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, train_loss[{loss_info}], " - f"batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - + ( - f", grad_scale: {cur_grad_scale}" - if params.dtype in ["float16", "fp16"] - else "" - ) - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train, - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.dtype in ["float16", "fp16"]: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if params.batch_idx_train % params.valid_interval == 0: - # Calculate validation loss in Rank 0 - model.eval() - logging.info("Computing validation loss") - with torch.amp.autocast("cuda", dtype=dtype): - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - model.train() - - loss_value = tot_loss["loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def filter_short_and_long_utterances( - cuts: CutSet, min_duration: float, max_duration: float -) -> CutSet: - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 0.6 second and 20 seconds - if c.duration < min_duration or c.duration > max_duration: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - cuts = cuts.filter(remove_short_and_long_utt) - - return cuts - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - rng = random.Random(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - torch.backends.cudnn.allow_tf32 = True - torch.backends.cuda.matmul.allow_tf32 = True - - logging.info(f"Device: {device}") - tokenizer = get_tokenizer(params.tokens) - logging.info(params) - - logging.info("About to create model") - - model = get_model(params) - - if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: - model = load_F5_TTS_pretrained_checkpoint( - model, params.pretrained_model_path - ) - else: - _ = load_checkpoint( - params.pretrained_model_path, - model=model, - ) - - model = model.to(device) - - with open(f"{params.exp_dir}/model.txt", "w") as f: - print(model) - print(model, file=f) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0 and params.average_period > 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=False) - - model_parameters = model.parameters() - - optimizer = torch.optim.AdamW( - model_parameters, - lr=params.base_lr, - betas=(0.9, 0.95), - weight_decay=1e-2, - eps=1e-8, - ) - - warmup_scheduler = LinearLR( - optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps - ) - decay_scheduler = LinearLR( - optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps - ) - scheduler = SequentialLR( - optimizer, - schedulers=[warmup_scheduler, decay_scheduler], - milestones=[params.warmup_steps], - ) - - optimizer.zero_grad() - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.inf_check: - register_inf_check_hooks(model) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - dataset = TtsDataModule(args) - train_cuts = dataset.train_cuts() - valid_cuts = dataset.valid_cuts() - - train_cuts = filter_short_and_long_utterances( - train_cuts, params.filter_min_duration, params.filter_max_duration - ) - valid_cuts = filter_short_and_long_utterances( - valid_cuts, params.filter_min_duration, params.filter_max_duration - 
) - - train_dl = dataset.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - valid_dl = dataset.valid_dataloaders(valid_cuts) - - if params.oom_check: - scan_pessimistic_batches_for_oom( - model=model, - tokenizer=tokenizer, - train_dl=train_dl, - optimizer=optimizer, - params=params, - ) - - scaler = GradScaler( - "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 - ) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - rng=rng, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - tokenizer, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - dtype = torch.float32 - if params.dtype in ["bfloat16", "bf16"]: - dtype = torch.bfloat16 - elif params.dtype in ["float16", "fp16"]: - dtype = torch.float16 - - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - print(batch.keys()) - try: - with torch.amp.autocast("cuda", dtype=dtype): - loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - loss.backward(retain_graph=True) - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - TtsDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py deleted file mode 100644 index eab7588b7..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset, - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from speech_synthesis import SpeechSynthesisDataset # noqa F401 -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class TtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - parser.add_argument( - "--prefix", - type=str, - default="wenetspeech4tts", - help="prefix of the manifest file", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - raise NotImplementedError( - "On-the-fly feature extraction is not implemented yet." 
- ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - pin_memory=True, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - raise NotImplementedError( - "On-the-fly feature extraction is not implemented yet." - ) - else: - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=True, - pin_memory=True, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - raise NotImplementedError( - "On-the-fly feature extraction is not implemented yet." 
- ) - else: - test = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz" - ) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/utils.py deleted file mode 120000 index ceaaea196..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt b/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt deleted file mode 100644 index 93f8b48b2..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt +++ /dev/null @@ -1,2545 +0,0 @@ - -! -" -# -$ -% -& -' -( -) -* -+ -, -- -. -/ -0 -1 -2 -3 -4 -5 -6 -7 -8 -9 -: -; -= -> -? -@ -A -B -C -D -E -F -G -H -I -J -K -L -M -N -O -P -Q -R -S -T -U -V -W -X -Y -Z -[ -\ -] -_ -a -a1 -ai1 -ai2 -ai3 -ai4 -an1 -an3 -an4 -ang1 -ang2 -ang4 -ao1 -ao2 -ao3 -ao4 -b -ba -ba1 -ba2 -ba3 -ba4 -bai1 -bai2 -bai3 -bai4 -ban1 -ban2 -ban3 -ban4 -bang1 -bang2 -bang3 -bang4 -bao1 -bao2 -bao3 -bao4 -bei -bei1 -bei2 -bei3 -bei4 -ben1 -ben2 -ben3 -ben4 -beng -beng1 -beng2 -beng3 -beng4 -bi1 -bi2 -bi3 -bi4 -bian1 -bian2 -bian3 -bian4 -biao1 -biao2 -biao3 -bie1 -bie2 -bie3 -bie4 -bin1 -bin4 -bing1 -bing2 -bing3 -bing4 -bo -bo1 -bo2 -bo3 -bo4 -bu2 -bu3 -bu4 -c -ca1 -cai1 -cai2 -cai3 -cai4 -can1 -can2 -can3 -can4 -cang1 -cang2 -cao1 -cao2 -cao3 -ce4 -cen1 -cen2 -ceng1 -ceng2 -ceng4 -cha1 -cha2 -cha3 -cha4 -chai1 -chai2 -chan1 -chan2 -chan3 -chan4 -chang1 -chang2 -chang3 -chang4 -chao1 -chao2 -chao3 -che1 -che2 -che3 -che4 -chen1 -chen2 -chen3 -chen4 -cheng1 -cheng2 -cheng3 -cheng4 -chi1 -chi2 -chi3 -chi4 -chong1 -chong2 -chong3 -chong4 -chou1 -chou2 -chou3 -chou4 -chu1 -chu2 -chu3 -chu4 -chua1 -chuai1 -chuai2 -chuai3 -chuai4 -chuan1 -chuan2 -chuan3 -chuan4 -chuang1 -chuang2 -chuang3 -chuang4 -chui1 -chui2 -chun1 -chun2 -chun3 -chuo1 -chuo4 -ci1 -ci2 -ci3 -ci4 -cong1 -cong2 -cou4 -cu1 -cu4 -cuan1 -cuan2 -cuan4 -cui1 -cui3 -cui4 -cun1 -cun2 -cun4 -cuo1 -cuo2 -cuo4 -d -da -da1 -da2 -da3 -da4 -dai1 -dai2 -dai3 -dai4 -dan1 -dan2 -dan3 -dan4 -dang1 -dang2 -dang3 -dang4 -dao1 -dao2 -dao3 -dao4 -de -de1 -de2 -dei3 -den4 -deng1 -deng2 -deng3 -deng4 -di1 -di2 -di3 -di4 -dia3 -dian1 -dian2 -dian3 -dian4 -diao1 -diao3 -diao4 -die1 -die2 -die4 -ding1 -ding2 -ding3 -ding4 -diu1 -dong1 -dong3 -dong4 -dou1 -dou2 -dou3 -dou4 -du1 -du2 -du3 -du4 -duan1 -duan2 -duan3 -duan4 -dui1 -dui4 -dun1 -dun3 -dun4 -duo1 -duo2 -duo3 -duo4 -e -e1 -e2 -e3 -e4 -ei2 -en1 -en4 -er -er2 -er3 -er4 -f -fa1 -fa2 -fa3 -fa4 -fan1 -fan2 -fan3 -fan4 -fang1 
-fang2 -fang3 -fang4 -fei1 -fei2 -fei3 -fei4 -fen1 -fen2 -fen3 -fen4 -feng1 -feng2 -feng3 -feng4 -fo2 -fou2 -fou3 -fu1 -fu2 -fu3 -fu4 -g -ga1 -ga2 -ga3 -ga4 -gai1 -gai2 -gai3 -gai4 -gan1 -gan2 -gan3 -gan4 -gang1 -gang2 -gang3 -gang4 -gao1 -gao2 -gao3 -gao4 -ge1 -ge2 -ge3 -ge4 -gei2 -gei3 -gen1 -gen2 -gen3 -gen4 -geng1 -geng3 -geng4 -gong1 -gong3 -gong4 -gou1 -gou2 -gou3 -gou4 -gu -gu1 -gu2 -gu3 -gu4 -gua1 -gua2 -gua3 -gua4 -guai1 -guai2 -guai3 -guai4 -guan1 -guan2 -guan3 -guan4 -guang1 -guang2 -guang3 -guang4 -gui1 -gui2 -gui3 -gui4 -gun3 -gun4 -guo1 -guo2 -guo3 -guo4 -h -ha1 -ha2 -ha3 -hai1 -hai2 -hai3 -hai4 -han1 -han2 -han3 -han4 -hang1 -hang2 -hang4 -hao1 -hao2 -hao3 -hao4 -he1 -he2 -he4 -hei1 -hen2 -hen3 -hen4 -heng1 -heng2 -heng4 -hong1 -hong2 -hong3 -hong4 -hou1 -hou2 -hou3 -hou4 -hu1 -hu2 -hu3 -hu4 -hua1 -hua2 -hua4 -huai2 -huai4 -huan1 -huan2 -huan3 -huan4 -huang1 -huang2 -huang3 -huang4 -hui1 -hui2 -hui3 -hui4 -hun1 -hun2 -hun4 -huo -huo1 -huo2 -huo3 -huo4 -i -j -ji1 -ji2 -ji3 -ji4 -jia -jia1 -jia2 -jia3 -jia4 -jian1 -jian2 -jian3 -jian4 -jiang1 -jiang2 -jiang3 -jiang4 -jiao1 -jiao2 -jiao3 -jiao4 -jie1 -jie2 -jie3 -jie4 -jin1 -jin2 -jin3 -jin4 -jing1 -jing2 -jing3 -jing4 -jiong3 -jiu1 -jiu2 -jiu3 -jiu4 -ju1 -ju2 -ju3 -ju4 -juan1 -juan2 -juan3 -juan4 -jue1 -jue2 -jue4 -jun1 -jun4 -k -ka1 -ka2 -ka3 -kai1 -kai2 -kai3 -kai4 -kan1 -kan2 -kan3 -kan4 -kang1 -kang2 -kang4 -kao1 -kao2 -kao3 -kao4 -ke1 -ke2 -ke3 -ke4 -ken3 -keng1 -kong1 -kong3 -kong4 -kou1 -kou2 -kou3 -kou4 -ku1 -ku2 -ku3 -ku4 -kua1 -kua3 -kua4 -kuai3 -kuai4 -kuan1 -kuan2 -kuan3 -kuang1 -kuang2 -kuang4 -kui1 -kui2 -kui3 -kui4 -kun1 -kun3 -kun4 -kuo4 -l -la -la1 -la2 -la3 -la4 -lai2 -lai4 -lan2 -lan3 -lan4 -lang1 -lang2 -lang3 -lang4 -lao1 -lao2 -lao3 -lao4 -le -le1 -le4 -lei -lei1 -lei2 -lei3 -lei4 -leng1 -leng2 -leng3 -leng4 -li -li1 -li2 -li3 -li4 -lia3 -lian2 -lian3 -lian4 -liang2 -liang3 -liang4 -liao1 -liao2 -liao3 -liao4 -lie1 -lie2 -lie3 -lie4 -lin1 -lin2 -lin3 -lin4 -ling2 -ling3 -ling4 -liu1 -liu2 -liu3 -liu4 -long1 -long2 -long3 -long4 -lou1 -lou2 -lou3 -lou4 -lu1 -lu2 -lu3 -lu4 -luan2 -luan3 -luan4 -lun1 -lun2 -lun4 -luo1 -luo2 -luo3 -luo4 -lv2 -lv3 -lv4 -lve3 -lve4 -m -ma -ma1 -ma2 -ma3 -ma4 -mai2 -mai3 -mai4 -man1 -man2 -man3 -man4 -mang2 -mang3 -mao1 -mao2 -mao3 -mao4 -me -mei2 -mei3 -mei4 -men -men1 -men2 -men4 -meng -meng1 -meng2 -meng3 -meng4 -mi1 -mi2 -mi3 -mi4 -mian2 -mian3 -mian4 -miao1 -miao2 -miao3 -miao4 -mie1 -mie4 -min2 -min3 -ming2 -ming3 -ming4 -miu4 -mo1 -mo2 -mo3 -mo4 -mou1 -mou2 -mou3 -mu2 -mu3 -mu4 -n -n2 -na1 -na2 -na3 -na4 -nai2 -nai3 -nai4 -nan1 -nan2 -nan3 -nan4 -nang1 -nang2 -nang3 -nao1 -nao2 -nao3 -nao4 -ne -ne2 -ne4 -nei3 -nei4 -nen4 -neng2 -ni1 -ni2 -ni3 -ni4 -nian1 -nian2 -nian3 -nian4 -niang2 -niang4 -niao2 -niao3 -niao4 -nie1 -nie4 -nin2 -ning2 -ning3 -ning4 -niu1 -niu2 -niu3 -niu4 -nong2 -nong4 -nou4 -nu2 -nu3 -nu4 -nuan3 -nuo2 -nuo4 -nv2 -nv3 -nve4 -o -o1 -o2 -ou1 -ou2 -ou3 -ou4 -p -pa1 -pa2 -pa4 -pai1 -pai2 -pai3 -pai4 -pan1 -pan2 -pan4 -pang1 -pang2 -pang4 -pao1 -pao2 -pao3 -pao4 -pei1 -pei2 -pei4 -pen1 -pen2 -pen4 -peng1 -peng2 -peng3 -peng4 -pi1 -pi2 -pi3 -pi4 -pian1 -pian2 -pian4 -piao1 -piao2 -piao3 -piao4 -pie1 -pie2 -pie3 -pin1 -pin2 -pin3 -pin4 -ping1 -ping2 -po1 -po2 -po3 -po4 -pou1 -pu1 -pu2 -pu3 -pu4 -q -qi1 -qi2 -qi3 -qi4 -qia1 -qia3 -qia4 -qian1 -qian2 -qian3 -qian4 -qiang1 -qiang2 -qiang3 -qiang4 -qiao1 -qiao2 -qiao3 -qiao4 -qie1 -qie2 -qie3 -qie4 -qin1 -qin2 -qin3 -qin4 -qing1 -qing2 -qing3 -qing4 -qiong1 -qiong2 -qiu1 -qiu2 -qiu3 -qu1 -qu2 -qu3 -qu4 -quan1 
-quan2 -quan3 -quan4 -que1 -que2 -que4 -qun2 -r -ran2 -ran3 -rang1 -rang2 -rang3 -rang4 -rao2 -rao3 -rao4 -re2 -re3 -re4 -ren2 -ren3 -ren4 -reng1 -reng2 -ri4 -rong1 -rong2 -rong3 -rou2 -rou4 -ru2 -ru3 -ru4 -ruan2 -ruan3 -rui3 -rui4 -run4 -ruo4 -s -sa1 -sa2 -sa3 -sa4 -sai1 -sai4 -san1 -san2 -san3 -san4 -sang1 -sang3 -sang4 -sao1 -sao2 -sao3 -sao4 -se4 -sen1 -seng1 -sha1 -sha2 -sha3 -sha4 -shai1 -shai2 -shai3 -shai4 -shan1 -shan3 -shan4 -shang -shang1 -shang3 -shang4 -shao1 -shao2 -shao3 -shao4 -she1 -she2 -she3 -she4 -shei2 -shen1 -shen2 -shen3 -shen4 -sheng1 -sheng2 -sheng3 -sheng4 -shi -shi1 -shi2 -shi3 -shi4 -shou1 -shou2 -shou3 -shou4 -shu1 -shu2 -shu3 -shu4 -shua1 -shua2 -shua3 -shua4 -shuai1 -shuai3 -shuai4 -shuan1 -shuan4 -shuang1 -shuang3 -shui2 -shui3 -shui4 -shun3 -shun4 -shuo1 -shuo4 -si1 -si2 -si3 -si4 -song1 -song3 -song4 -sou1 -sou3 -sou4 -su1 -su2 -su4 -suan1 -suan4 -sui1 -sui2 -sui3 -sui4 -sun1 -sun3 -suo -suo1 -suo2 -suo3 -t -ta1 -ta2 -ta3 -ta4 -tai1 -tai2 -tai4 -tan1 -tan2 -tan3 -tan4 -tang1 -tang2 -tang3 -tang4 -tao1 -tao2 -tao3 -tao4 -te4 -teng2 -ti1 -ti2 -ti3 -ti4 -tian1 -tian2 -tian3 -tiao1 -tiao2 -tiao3 -tiao4 -tie1 -tie2 -tie3 -tie4 -ting1 -ting2 -ting3 -tong1 -tong2 -tong3 -tong4 -tou -tou1 -tou2 -tou4 -tu1 -tu2 -tu3 -tu4 -tuan1 -tuan2 -tui1 -tui2 -tui3 -tui4 -tun1 -tun2 -tun4 -tuo1 -tuo2 -tuo3 -tuo4 -u -v -w -wa -wa1 -wa2 -wa3 -wa4 -wai1 -wai3 -wai4 -wan1 -wan2 -wan3 -wan4 -wang1 -wang2 -wang3 -wang4 -wei1 -wei2 -wei3 -wei4 -wen1 -wen2 -wen3 -wen4 -weng1 -weng4 -wo1 -wo2 -wo3 -wo4 -wu1 -wu2 -wu3 -wu4 -x -xi1 -xi2 -xi3 -xi4 -xia1 -xia2 -xia4 -xian1 -xian2 -xian3 -xian4 -xiang1 -xiang2 -xiang3 -xiang4 -xiao1 -xiao2 -xiao3 -xiao4 -xie1 -xie2 -xie3 -xie4 -xin1 -xin2 -xin4 -xing1 -xing2 -xing3 -xing4 -xiong1 -xiong2 -xiu1 -xiu3 -xiu4 -xu -xu1 -xu2 -xu3 -xu4 -xuan1 -xuan2 -xuan3 -xuan4 -xue1 -xue2 -xue3 -xue4 -xun1 -xun2 -xun4 -y -ya -ya1 -ya2 -ya3 -ya4 -yan1 -yan2 -yan3 -yan4 -yang1 -yang2 -yang3 -yang4 -yao1 -yao2 -yao3 -yao4 -ye1 -ye2 -ye3 -ye4 -yi -yi1 -yi2 -yi3 -yi4 -yin1 -yin2 -yin3 -yin4 -ying1 -ying2 -ying3 -ying4 -yo1 -yong1 -yong2 -yong3 -yong4 -you1 -you2 -you3 -you4 -yu1 -yu2 -yu3 -yu4 -yuan1 -yuan2 -yuan3 -yuan4 -yue1 -yue4 -yun1 -yun2 -yun3 -yun4 -z -za1 -za2 -za3 -zai1 -zai3 -zai4 -zan1 -zan2 -zan3 -zan4 -zang1 -zang4 -zao1 -zao2 -zao3 -zao4 -ze2 -ze4 -zei2 -zen3 -zeng1 -zeng4 -zha1 -zha2 -zha3 -zha4 -zhai1 -zhai2 -zhai3 -zhai4 -zhan1 -zhan2 -zhan3 -zhan4 -zhang1 -zhang2 -zhang3 -zhang4 -zhao1 -zhao2 -zhao3 -zhao4 -zhe -zhe1 -zhe2 -zhe3 -zhe4 -zhen1 -zhen2 -zhen3 -zhen4 -zheng1 -zheng2 -zheng3 -zheng4 -zhi1 -zhi2 -zhi3 -zhi4 -zhong1 -zhong2 -zhong3 -zhong4 -zhou1 -zhou2 -zhou3 -zhou4 -zhu1 -zhu2 -zhu3 -zhu4 -zhua1 -zhua2 -zhua3 -zhuai1 -zhuai3 -zhuai4 -zhuan1 -zhuan2 -zhuan3 -zhuan4 -zhuang1 -zhuang4 -zhui1 -zhui4 -zhun1 -zhun2 -zhun3 -zhuo1 -zhuo2 -zi -zi1 -zi2 -zi3 -zi4 -zong1 -zong2 -zong3 -zong4 -zou1 -zou2 -zou3 -zou4 -zu1 -zu2 -zu3 -zuan1 -zuan3 -zuan4 -zui2 -zui3 -zui4 -zun1 -zuo -zuo1 -zuo2 -zuo3 -zuo4 -{ -~ -¡ -¢ -£ -¥ -§ -¨ -© -« -® -¯ -° -± -² -³ -´ -µ -· -¹ -º -» -¼ -½ -¾ -¿ -À -Á - -à -Ä -Å -Æ -Ç -È -É -Ê -Í -Î -Ñ -Ó -Ö -× -Ø -Ú -Ü -Ý -Þ -ß -à -á -â -ã -ä -å -æ -ç -è -é -ê -ë -ì -í -î -ï -ð -ñ -ò -ó -ô -õ -ö -ø -ù -ú -û -ü -ý -Ā -ā -ă -ą -ć -Č -č -Đ -đ -ē -ė -ę -ě -ĝ -ğ -ħ -ī -į -İ -ı -Ł -ł -ń -ņ -ň -ŋ -Ō -ō -ő -œ -ř -Ś -ś -Ş -ş -Š -š -Ť -ť -ũ -ū -ź -Ż -ż -Ž -ž -ơ -ư -ǎ -ǐ -ǒ -ǔ -ǚ -ș -ț -ɑ -ɔ -ɕ -ə -ɛ -ɜ -ɡ -ɣ -ɪ -ɫ -ɴ -ɹ -ɾ -ʃ -ʊ -ʌ -ʒ -ʔ -ʰ -ʷ -ʻ -ʾ -ʿ -ˈ -ː -˙ -˜ -ˢ -́ -̅ -Α -Β -Δ -Ε -Θ -Κ -Λ -Μ -Ξ -Π -Σ -Τ -Φ -Χ -Ψ -Ω -ά -έ -ή -ί -α -β -γ 
-δ -ε -ζ -η -θ -ι -κ -λ -μ -ν -ξ -ο -π -ρ -ς -σ -τ -υ -φ -χ -ψ -ω -ϊ -ό -ύ -ώ -ϕ -ϵ -Ё -А -Б -В -Г -Д -Е -Ж -З -И -Й -К -Л -М -Н -О -П -Р -С -Т -У -Ф -Х -Ц -Ч -Ш -Щ -Ы -Ь -Э -Ю -Я -а -б -в -г -д -е -ж -з -и -й -к -л -м -н -о -п -р -с -т -у -ф -х -ц -ч -ш -щ -ъ -ы -ь -э -ю -я -ё -і -ְ -ִ -ֵ -ֶ -ַ -ָ -ֹ -ּ -־ -ׁ -א -ב -ג -ד -ה -ו -ז -ח -ט -י -כ -ל -ם -מ -ן -נ -ס -ע -פ -ק -ר -ש -ת -أ -ب -ة -ت -ج -ح -د -ر -ز -س -ص -ط -ع -ق -ك -ل -م -ن -ه -و -ي -َ -ُ -ِ -ْ -ก -ข -ง -จ -ต -ท -น -ป -ย -ร -ว -ส -ห -อ -ฮ -ั -า -ี -ึ -โ -ใ -ไ -่ -้ -์ -ḍ -Ḥ -ḥ -ṁ -ṃ -ṅ -ṇ -Ṛ -ṛ -Ṣ -ṣ -Ṭ -ṭ -ạ -ả -Ấ -ấ -ầ -ậ -ắ -ằ -ẻ -ẽ -ế -ề -ể -ễ -ệ -ị -ọ -ỏ -ố -ồ -ộ -ớ -ờ -ở -ụ -ủ -ứ -ữ -ἀ -ἁ -Ἀ -ἐ -ἔ -ἰ -ἱ -ὀ -ὁ -ὐ -ὲ -ὸ -ᾶ -᾽ -ῆ -ῇ -ῶ -‎ -‑ -‒ -– -— -― -‖ -† -‡ -• -… -‧ -‬ -′ -″ -⁄ -⁡ -⁰ -⁴ -⁵ -⁶ -⁷ -⁸ -⁹ -₁ -₂ -₃ -€ -₱ -₹ -₽ -℃ -ℏ -ℓ -№ -ℝ -™ -⅓ -⅔ -⅛ -→ -∂ -∈ -∑ -− -∗ -√ -∞ -∫ -≈ -≠ -≡ -≤ -≥ -⋅ -⋯ -█ -♪ -⟨ -⟩ -、 -。 -《 -》 -「 -」 -【 -】 -あ -う -え -お -か -が -き -ぎ -く -ぐ -け -げ -こ -ご -さ -し -じ -す -ず -せ -ぜ -そ -ぞ -た -だ -ち -っ -つ -で -と -ど -な -に -ね -の -は -ば -ひ -ぶ -へ -べ -ま -み -む -め -も -ゃ -や -ゆ -ょ -よ -ら -り -る -れ -ろ -わ -を -ん -ァ -ア -ィ -イ -ウ -ェ -エ -オ -カ -ガ -キ -ク -ケ -ゲ -コ -ゴ -サ -ザ -シ -ジ -ス -ズ -セ -ゾ -タ -ダ -チ -ッ -ツ -テ -デ -ト -ド -ナ -ニ -ネ -ノ -バ -パ -ビ -ピ -フ -プ -ヘ -ベ -ペ -ホ -ボ -ポ -マ -ミ -ム -メ -モ -ャ -ヤ -ュ -ユ -ョ -ヨ -ラ -リ -ル -レ -ロ -ワ -ン -・ -ー -ㄋ -ㄍ -ㄎ -ㄏ -ㄓ -ㄕ -ㄚ -ㄜ -ㄟ -ㄤ -ㄥ -ㄧ -ㄱ -ㄴ -ㄷ -ㄹ -ㅁ -ㅂ -ㅅ -ㅈ -ㅍ -ㅎ -ㅏ -ㅓ -ㅗ -ㅜ -ㅡ -ㅣ -㗎 -가 -각 -간 -갈 -감 -갑 -갓 -갔 -강 -같 -개 -거 -건 -걸 -겁 -것 -겉 -게 -겠 -겨 -결 -겼 -경 -계 -고 -곤 -골 -곱 -공 -과 -관 -광 -교 -구 -국 -굴 -귀 -귄 -그 -근 -글 -금 -기 -긴 -길 -까 -깍 -깔 -깜 -깨 -께 -꼬 -꼭 -꽃 -꾸 -꿔 -끔 -끗 -끝 -끼 -나 -난 -날 -남 -납 -내 -냐 -냥 -너 -넘 -넣 -네 -녁 -년 -녕 -노 -녹 -놀 -누 -눈 -느 -는 -늘 -니 -님 -닙 -다 -닥 -단 -달 -닭 -당 -대 -더 -덕 -던 -덥 -데 -도 -독 -동 -돼 -됐 -되 -된 -될 -두 -둑 -둥 -드 -들 -등 -디 -따 -딱 -딸 -땅 -때 -떤 -떨 -떻 -또 -똑 -뚱 -뛰 -뜻 -띠 -라 -락 -란 -람 -랍 -랑 -래 -랜 -러 -런 -럼 -렇 -레 -려 -력 -렵 -렸 -로 -록 -롬 -루 -르 -른 -를 -름 -릉 -리 -릴 -림 -마 -막 -만 -많 -말 -맑 -맙 -맛 -매 -머 -먹 -멍 -메 -면 -명 -몇 -모 -목 -몸 -못 -무 -문 -물 -뭐 -뭘 -미 -민 -밌 -밑 -바 -박 -밖 -반 -받 -발 -밤 -밥 -방 -배 -백 -밸 -뱀 -버 -번 -벌 -벚 -베 -벼 -벽 -별 -병 -보 -복 -본 -볼 -봐 -봤 -부 -분 -불 -비 -빔 -빛 -빠 -빨 -뼈 -뽀 -뿅 -쁘 -사 -산 -살 -삼 -샀 -상 -새 -색 -생 -서 -선 -설 -섭 -섰 -성 -세 -셔 -션 -셨 -소 -속 -손 -송 -수 -숙 -순 -술 -숫 -숭 -숲 -쉬 -쉽 -스 -슨 -습 -슷 -시 -식 -신 -실 -싫 -심 -십 -싶 -싸 -써 -쓰 -쓴 -씌 -씨 -씩 -씬 -아 -악 -안 -않 -알 -야 -약 -얀 -양 -얘 -어 -언 -얼 -엄 -업 -없 -었 -엉 -에 -여 -역 -연 -염 -엽 -영 -옆 -예 -옛 -오 -온 -올 -옷 -옹 -와 -왔 -왜 -요 -욕 -용 -우 -운 -울 -웃 -워 -원 -월 -웠 -위 -윙 -유 -육 -윤 -으 -은 -을 -음 -응 -의 -이 -익 -인 -일 -읽 -임 -입 -있 -자 -작 -잔 -잖 -잘 -잡 -잤 -장 -재 -저 -전 -점 -정 -제 -져 -졌 -조 -족 -좀 -종 -좋 -죠 -주 -준 -줄 -중 -줘 -즈 -즐 -즘 -지 -진 -집 -짜 -짝 -쩌 -쪼 -쪽 -쫌 -쭈 -쯔 -찌 -찍 -차 -착 -찾 -책 -처 -천 -철 -체 -쳐 -쳤 -초 -촌 -추 -출 -춤 -춥 -춰 -치 -친 -칠 -침 -칩 -칼 -커 -켓 -코 -콩 -쿠 -퀴 -크 -큰 -큽 -키 -킨 -타 -태 -터 -턴 -털 -테 -토 -통 -투 -트 -특 -튼 -틀 -티 -팀 -파 -팔 -패 -페 -펜 -펭 -평 -포 -폭 -표 -품 -풍 -프 -플 -피 -필 -하 -학 -한 -할 -함 -합 -항 -해 -햇 -했 -행 -허 -험 -형 -혜 -호 -혼 -홀 -화 -회 -획 -후 -휴 -흐 -흔 -희 -히 -힘 -ﷺ -ﷻ -! -, -? -� -𠮶 diff --git a/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py b/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py deleted file mode 100644 index 9904901f0..000000000 --- a/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 author: Yuekai Zhang -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import gzip -import json -import logging - -import s3tokenizer -from lhotse import CutSet, load_manifest_lazy -from tqdm import tqdm - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--manifest-dir", - type=str, - default="data/fbank", - help="Directory to store the manifest files", - ) - - parser.add_argument( - "--jsonl-prefix", - type=str, - default="wenetspeech4tts_cuts_valid", - help="The training subset for wenetspeech.", - ) - - parser.add_argument( - "--tokens-path", - type=str, - default="./s3_tokens_valid/wenetspeech4tts_valid.json", - help="json file containing the speech tokens", - ) - - return parser - - -def get_speech_tokens(tokens_path): - id2tokens = {} - with open(tokens_path, "r") as fin: - for line in fin: - line = json.loads(line) - id2tokens[line["key"]] = " ".join(map(str, line["code"])) - return id2tokens - - -def attach_manifest(manifest, fixed_manifest_path, id2tokens): - with CutSet.open_writer(fixed_manifest_path) as manifest_writer: - fixed_item = 0 - for i, cut in enumerate(tqdm(manifest)): - cut_id = cut.supervisions[0].id - if cut_id in id2tokens: - code = id2tokens[cut_id] - cut.supervisions[0].custom = { - **cut.supervisions[0].custom, - **{"speech_tokens": code}, - } - else: - print(f"cut_id {cut_id} not in id2tokens") - fixed_item += 1 - manifest_writer.write(cut) - logging.info(f"Fixed {fixed_item} items in the manifest") - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - manifest_path = args.manifest_dir + "/" + f"{args.jsonl_prefix}.jsonl.gz" - attached_manifest_path = ( - args.manifest_dir + "/" + f"{args.jsonl_prefix}_attached_cosyvoice_v2.jsonl.gz" - ) - logging.info(f"Loading manifest from {manifest_path}") - cuts_manifest = load_manifest_lazy(manifest_path) - logging.info(f"Loading manifest from {manifest_path} done") - id2tokens = get_speech_tokens(args.tokens_path) - logging.info(f"Loaded id2tokens with {len(id2tokens)} entries") - - attach_manifest(cuts_manifest, attached_manifest_path, id2tokens) - logging.info( - f"Manifest with speech tokens attached is saved to {attached_manifest_path}" - ) - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/local/audio.py b/egs/wenetspeech4tts/TTS/local/audio.py deleted file mode 100644 index b643e3de0..000000000 --- a/egs/wenetspeech4tts/TTS/local/audio.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2024 NVIDIA CORPORATION. -# Licensed under the MIT license. - -# Adapted from https://github.com/jik876/hifi-gan under the MIT license. -# LICENSE is in incl_licenses directory. 
- -import math -import os -import pathlib -import random -from typing import List, Optional, Tuple - -import librosa -import numpy as np -import torch -import torch.utils.data -from librosa.filters import mel as librosa_mel_fn -from tqdm import tqdm - -# from env import AttrDict - -MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) - - -def dynamic_range_compression(x, C=1, clip_val=1e-5): - return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) - - -def dynamic_range_decompression(x, C=1): - return np.exp(x) / C - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - return dynamic_range_compression_torch(magnitudes) - - -def spectral_de_normalize_torch(magnitudes): - return dynamic_range_decompression_torch(magnitudes) - - -mel_basis_cache = {} -hann_window_cache = {} - - -def mel_spectrogram( - y: torch.Tensor, - n_fft: int = 1024, - num_mels: int = 100, - sampling_rate: int = 24_000, - hop_size: int = 256, - win_size: int = 1024, - fmin: int = 0, - fmax: int = None, - center: bool = False, -) -> torch.Tensor: - """ - Calculate the mel spectrogram of an input signal. - This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). - - Args: - y (torch.Tensor): Input signal. - n_fft (int): FFT size. - num_mels (int): Number of mel bins. - sampling_rate (int): Sampling rate of the input signal. - hop_size (int): Hop size for STFT. - win_size (int): Window size for STFT. - fmin (int): Minimum frequency for mel filterbank. - fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn - center (bool): Whether to pad the input to center the frames. Default is False. - - Returns: - torch.Tensor: Mel spectrogram. - """ - if torch.min(y) < -1.0: - print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") - if torch.max(y) > 1.0: - print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") - - device = y.device - key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" - - if key not in mel_basis_cache: - mel = librosa_mel_fn( - sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax - ) - mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) - hann_window_cache[key] = torch.hann_window(win_size).to(device) - - mel_basis = mel_basis_cache[key] - hann_window = hann_window_cache[key] - - padding = (n_fft - hop_size) // 2 - y = torch.nn.functional.pad( - y.unsqueeze(1), (padding, padding), mode="reflect" - ).squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) - - mel_spec = torch.matmul(mel_basis, spec) - mel_spec = spectral_normalize_torch(mel_spec) - - return mel_spec diff --git a/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py deleted file mode 100755 index 5292c75ad..000000000 --- a/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py +++ /dev/null @@ -1,218 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LJSpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path - -import torch -from fbank import MatchaFbank, MatchaFbankConfig -from lhotse import CutSet, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--num-jobs", - type=int, - default=1, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--src-dir", - type=Path, - default=Path("data/manifests"), - help="Path to the manifest files", - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("data/fbank"), - help="Path to the tokenized files", - ) - - parser.add_argument( - "--dataset-parts", - type=str, - default="Basic", - help="Space separated dataset parts", - ) - - parser.add_argument( - "--prefix", - type=str, - default="wenetspeech4tts", - help="prefix of the manifest file", - ) - - parser.add_argument( - "--suffix", - type=str, - default="jsonl.gz", - help="suffix of the manifest file", - ) - - parser.add_argument( - "--split", - type=int, - default=100, - help="Split the cut_set into multiple parts", - ) - - parser.add_argument( - "--resample-to-24kHz", - default=True, - help="Resample the audio to 24kHz", - ) - - parser.add_argument( - "--extractor", - type=str, - choices=["bigvgan", "hifigan"], - default="bigvgan", - help="The type of extractor to use", - ) - return parser - - -def compute_fbank(args): - src_dir = Path(args.src_dir) - output_dir = Path(args.output_dir) - Path(args.output_dir).mkdir(parents=True, exist_ok=True) - - num_jobs = min(args.num_jobs, os.cpu_count()) - dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ") - - logging.info(f"num_jobs: {num_jobs}") - logging.info(f"src_dir: {src_dir}") - logging.info(f"output_dir: {output_dir}") - logging.info(f"dataset_parts: {dataset_parts}") - if args.extractor == "bigvgan": - config = MatchaFbankConfig( - n_fft=1024, - n_mels=100, - sampling_rate=24_000, - hop_length=256, - win_length=1024, - f_min=0, - f_max=None, - ) - elif args.extractor == "hifigan": - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=22050, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - else: - raise NotImplementedError(f"Extractor {args.extractor} is not implemented") - - extractor = MatchaFbank(config) - - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=args.src_dir, - prefix=args.prefix, - suffix=args.suffix, - 
types=["recordings", "supervisions", "cuts"], - ) - - with get_executor() as ex: - for partition, m in manifests.items(): - logging.info( - f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" - ) - try: - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - except Exception: - cut_set = m["cuts"] - - if args.split > 1: - cut_sets = cut_set.split(args.split) - else: - cut_sets = [cut_set] - - for idx, part in enumerate(cut_sets): - if args.split > 1: - storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}" - else: - storage_path = ( - f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}" - ) - - if args.resample_to_24kHz: - part = part.resample(24000) - - with torch.no_grad(): - part = part.compute_and_store_features( - extractor=extractor, - storage_path=storage_path, - num_jobs=num_jobs if ex is None else 64, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - - if args.split > 1: - cuts_filename = ( - f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}" - ) - else: - cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}" - - part.to_file(f"{args.output_dir}/{cuts_filename}") - logging.info(f"Saved {cuts_filename}") - - -if __name__ == "__main__": - # Torch's multithreaded behavior needs to be disabled or - # it wastes a lot of CPU and slow things down. - # Do this outside of main() in case it needs to take effect - # even when we are not invoking the main (e.g. when spawning subprocesses). - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_parser().parse_args() - compute_fbank(args) diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py deleted file mode 100755 index 7de2c6202..000000000 --- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py +++ /dev/null @@ -1,621 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 (authors: Feiteng Li) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Phonemize Text and EnCodec Audio. 
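[Editor's note] The feature archives and cut manifests written by compute_mel_feat.py above follow a fixed naming scheme. The sketch below isolates that string logic (pure Python, no I/O) with the recipe's default prefix and extractor, so it is easy to predict where a given split chunk ends up.

def output_names(output_dir, prefix, extractor, partition, suffix, split, idx):
    # Mirrors the storage_path / cuts_filename construction in compute_fbank() above.
    if split > 1:
        storage_path = f"{output_dir}/{prefix}_{extractor}_{partition}_{idx}"
        cuts_filename = f"{prefix}_cuts_{partition}.{idx}.{suffix}"
    else:
        storage_path = f"{output_dir}/{prefix}_{extractor}_{partition}"
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
    return storage_path, cuts_filename

print(output_names("data/fbank", "wenetspeech4tts", "bigvgan", "Basic", "jsonl.gz", split=100, idx=3))
# -> ('data/fbank/wenetspeech4tts_bigvgan_Basic_3', 'wenetspeech4tts_cuts_Basic.3.jsonl.gz')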
- -Usage example: - python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ - --text-extractor ${text_extractor} \ - --audio-extractor ${audio_extractor} \ - --batch-duration 2500 --prefix "wenetspeech4tts" \ - --src-dir "data/manifests" --split 100 \ - --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" - -""" -import argparse -import logging -import os -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.multiprocessing -from encodec import EncodecModel -from encodec.utils import convert_audio -from lhotse import CutSet, NumpyHdf5Writer -from lhotse.features import FeatureExtractor -from lhotse.recipes.utils import read_manifests_if_cached -from lhotse.utils import Seconds, compute_num_frames -from phonemizer.backend import EspeakBackend -from phonemizer.backend.espeak.language_switch import LanguageSwitch -from phonemizer.backend.espeak.words_mismatch import WordMismatch -from phonemizer.punctuation import Punctuation -from phonemizer.separator import Separator -from tqdm.auto import tqdm - -from icefall.utils import get_executor - -try: - from pypinyin import Style, pinyin - from pypinyin.style._utils import get_finals, get_initials -except Exception: - pass - - -import re -from typing import Pattern - -import numpy as np -from k2 import SymbolTable - -# from valle.data import ( -# AudioTokenConfig, -# AudioTokenExtractor, -# TextTokenizer, -# tokenize_text, -# ) -# from valle.data.fbank import get_fbank_extractor -# from valle.utils import SymbolTable - -os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--src-dir", - type=Path, - default=Path("data/manifests"), - help="Path to the manifest files", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=Path("data/tokenized"), - help="Path to the tokenized files", - ) - parser.add_argument( - "--text-extractor", - type=str, - default="espeak", - help="espeak or pypinyin or pypinyin_initials_finals", - ) - parser.add_argument( - "--audio-extractor", - type=str, - default="Encodec", - help="Encodec or Fbank", - ) - parser.add_argument( - "--dataset-parts", - type=str, - default="dev-clean test-clean", - help="Space separated dataset parts", - ) - parser.add_argument( - "--prefix", - type=str, - default="libritts", - help="prefix of the manifest file", - ) - parser.add_argument( - "--suffix", - type=str, - default="jsonl.gz", - help="suffix of the manifest file", - ) - parser.add_argument( - "--batch-duration", - type=float, - default=400.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - parser.add_argument( - "--split", - type=int, - default=1, - help="Split the cut_set into multiple parts", - ) - - return parser.parse_args() - - -class PypinyinBackend: - """PypinyinBackend for Chinese. Most codes is referenced from espnet. - There are two types pinyin or initials_finals, one is - just like "ni1 hao3", the other is like "n i1 h ao3". 
- """ - - def __init__( - self, - backend="initials_finals", - punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), - ) -> None: - self.backend = backend - self.punctuation_marks = punctuation_marks - - def phonemize( - self, text: List[str], separator: Separator, strip=True, njobs=1 - ) -> List[str]: - assert isinstance(text, List) - phonemized = [] - for _text in text: - _text = re.sub(" +", " ", _text.strip()) - _text = _text.replace(" ", separator.word) - phones = [] - if self.backend == "pypinyin": - for n, py in enumerate( - pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) - ): - if all([c in self.punctuation_marks for c in py[0]]): - if len(phones): - assert phones[-1] == separator.syllable - phones.pop(-1) - - phones.extend(list(py[0])) - else: - phones.extend([py[0], separator.syllable]) - elif self.backend == "pypinyin_initials_finals": - for n, py in enumerate( - pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) - ): - if all([c in self.punctuation_marks for c in py[0]]): - if len(phones): - assert phones[-1] == separator.syllable - phones.pop(-1) - phones.extend(list(py[0])) - else: - if py[0][-1].isalnum(): - initial = get_initials(py[0], strict=False) - if py[0][-1].isdigit(): - final = get_finals(py[0][:-1], strict=False) + py[0][-1] - else: - final = get_finals(py[0], strict=False) - phones.extend( - [ - initial, - separator.phone, - final, - separator.syllable, - ] - ) - else: - assert ValueError - else: - raise NotImplementedError - phonemized.append( - "".join(phones).rstrip(f"{separator.word}{separator.syllable}") - ) - return phonemized - - -class TextTokenizer: - """Phonemize Text.""" - - def __init__( - self, - language="en-us", - backend="espeak", - separator=Separator(word="_", syllable="-", phone="|"), - preserve_punctuation=True, - punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), - with_stress: bool = False, - tie: Union[bool, str] = False, - language_switch: LanguageSwitch = "keep-flags", - words_mismatch: WordMismatch = "ignore", - ) -> None: - if backend == "espeak": - phonemizer = EspeakBackend( - language, - punctuation_marks=punctuation_marks, - preserve_punctuation=preserve_punctuation, - with_stress=with_stress, - tie=tie, - language_switch=language_switch, - words_mismatch=words_mismatch, - ) - elif backend in ["pypinyin", "pypinyin_initials_finals"]: - phonemizer = PypinyinBackend( - backend=backend, - punctuation_marks=punctuation_marks + separator.word, - ) - else: - raise NotImplementedError(f"{backend}") - - self.backend = phonemizer - self.separator = separator - - def to_list(self, phonemized: str) -> List[str]: - fields = [] - for word in phonemized.split(self.separator.word): - # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 
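[Editor's note] To make the two pypinyin modes of the PypinyinBackend above concrete, here is a rough sketch of what they produce for a short string. Exact output can differ slightly across pypinyin versions, and the spacing shown is conceptual rather than the recipe's actual Separator symbols.

from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials

syllables = [p[0] for p in pinyin("你好", style=Style.TONE3, neutral_tone_with_five=True)]
print(syllables)  # expected: ['ni3', 'hao3']

# backend="pypinyin":                 each syllable kept whole   -> "ni3 hao3"
# backend="pypinyin_initials_finals": initial + final (+ tone)   -> "n i3 h ao3"
for syl in syllables:
    initial = get_initials(syl, strict=False)
    if syl[-1].isdigit():
        final = get_finals(syl[:-1], strict=False) + syl[-1]
    else:
        final = get_finals(syl, strict=False)
    print(initial, final)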
- pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) - fields.extend( - [p for p in pp if p != self.separator.phone] + [self.separator.word] - ) - assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( - self.separator.phone - ) - return fields[:-1] - - def __call__(self, text, strip=True) -> List[List[str]]: - if isinstance(text, str): - text = [text] - - phonemized = self.backend.phonemize( - text, separator=self.separator, strip=strip, njobs=1 - ) - return [self.to_list(p) for p in phonemized] - - -def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: - phonemes = tokenizer([text.strip()]) - return phonemes[0] # k2symbols - - -def remove_encodec_weight_norm(model): - from encodec.modules import SConv1d - from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock - from torch.nn.utils import remove_weight_norm - - encoder = model.encoder.model - for key in encoder._modules: - if isinstance(encoder._modules[key], SEANetResnetBlock): - remove_weight_norm(encoder._modules[key].shortcut.conv.conv) - block_modules = encoder._modules[key].block._modules - for skey in block_modules: - if isinstance(block_modules[skey], SConv1d): - remove_weight_norm(block_modules[skey].conv.conv) - elif isinstance(encoder._modules[key], SConv1d): - remove_weight_norm(encoder._modules[key].conv.conv) - - decoder = model.decoder.model - for key in decoder._modules: - if isinstance(decoder._modules[key], SEANetResnetBlock): - remove_weight_norm(decoder._modules[key].shortcut.conv.conv) - block_modules = decoder._modules[key].block._modules - for skey in block_modules: - if isinstance(block_modules[skey], SConv1d): - remove_weight_norm(block_modules[skey].conv.conv) - elif isinstance(decoder._modules[key], SConvTranspose1d): - remove_weight_norm(decoder._modules[key].convtr.convtr) - elif isinstance(decoder._modules[key], SConv1d): - remove_weight_norm(decoder._modules[key].conv.conv) - - -class AudioTokenizer: - """EnCodec audio.""" - - def __init__( - self, - device: Any = None, - ) -> None: - # Instantiate a pretrained EnCodec model - model = EncodecModel.encodec_model_24khz() - model.set_target_bandwidth(6.0) - remove_encodec_weight_norm(model) - - if not device: - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda:0") - - self._device = device - - self.codec = model.to(device) - self.sample_rate = model.sample_rate - self.channels = model.channels - - @property - def device(self): - return self._device - - def encode(self, wav: torch.Tensor) -> torch.Tensor: - return self.codec.encode(wav.to(self.device)) - - def decode(self, frames: torch.Tensor) -> torch.Tensor: - return self.codec.decode(frames) - - -@dataclass -class AudioTokenConfig: - frame_shift: Seconds = 320.0 / 24000 - num_quantizers: int = 8 - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - @staticmethod - def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": - return AudioTokenConfig(**data) - - -class AudioTokenExtractor(FeatureExtractor): - name = "encodec" - config_type = AudioTokenConfig - - def __init__(self, config: Optional[Any] = None): - super(AudioTokenExtractor, self).__init__(config) - self.tokenizer = AudioTokenizer() - - def extract( - self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int - ) -> np.ndarray: - if not isinstance(samples, torch.Tensor): - samples = torch.from_numpy(samples) - if sampling_rate != self.tokenizer.sample_rate: - samples = convert_audio( - samples, - sampling_rate, - 
self.tokenizer.sample_rate, - self.tokenizer.channels, - ) - if len(samples.shape) == 2: - samples = samples.unsqueeze(0) - else: - raise ValueError() - - device = self.tokenizer.device - encoded_frames = self.tokenizer.encode(samples.detach().to(device)) - codes = encoded_frames[0][0] # [B, n_q, T] - if True: - duration = round(samples.shape[-1] / sampling_rate, ndigits=12) - expected_num_frames = compute_num_frames( - duration=duration, - frame_shift=self.frame_shift, - sampling_rate=sampling_rate, - ) - assert abs(codes.shape[-1] - expected_num_frames) <= 1 - codes = codes[..., :expected_num_frames] - return codes.cpu().squeeze(0).permute(1, 0).numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.frame_shift - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.num_quantizers - - def pad_tensor_list(self, tensor_list, device, padding_value=0): - lengths = [tensor.shape[0] for tensor in tensor_list] - tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] - padded_tensor = torch.nn.utils.rnn.pad_sequence( - tensor_list, batch_first=True, padding_value=padding_value - ) - return padded_tensor, lengths - - def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: - samples = [wav.squeeze() for wav in samples] - device = self.tokenizer.device - samples, lengths = self.pad_tensor_list(samples, device) - samples = samples.unsqueeze(1) - - if not isinstance(samples, torch.Tensor): - samples = torch.from_numpy(samples) - if len(samples.shape) != 3: - raise ValueError() - if sampling_rate != self.tokenizer.sample_rate: - samples = [ - convert_audio( - wav, - sampling_rate, - self.tokenizer.sample_rate, - self.tokenizer.channels, - ) - for wav in samples - ] - samples = torch.stack(samples, 0) # convert samples from list to tensor - # Extract discrete codes from EnCodec - with torch.no_grad(): - encoded_frames = self.tokenizer.encode(samples.detach().to(device)) - encoded_frames = encoded_frames[0][0] # [B, n_q, T] - batch_codes = [] - for b, length in enumerate(lengths): - codes = encoded_frames[b] - duration = round(length / sampling_rate, ndigits=12) - expected_num_frames = compute_num_frames( - duration=duration, - frame_shift=self.frame_shift, - sampling_rate=sampling_rate, - ) - batch_codes.append(codes[..., :expected_num_frames]) - return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] - - -def main(): - args = get_args() - - dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() - if dataset_parts == "all": # LibriTTS - dataset_parts = [ - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ] - else: - dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") - - assert len(dataset_parts) >= 1 - - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=args.src_dir, - prefix=args.prefix, - suffix=args.suffix, - types=["recordings", "supervisions", "cuts"], - ) - - text_tokenizer = None - if args.text_extractor: - text_tokenizer = TextTokenizer(backend=args.text_extractor) - - audio_extractor = None - if args.audio_extractor: - if args.audio_extractor == "Encodec": - audio_extractor = AudioTokenExtractor(AudioTokenConfig()) - else: - raise NotImplementedError(f"{args.audio_extractor}") - - Path(args.output_dir).mkdir(parents=True, exist_ok=True) - unique_symbols = set() - num_jobs = min(32, os.cpu_count()) - logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") - 
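[Editor's note] Before the tokenization loop below starts, it can help to sanity-check the AudioTokenizer defined above on a dummy waveform. This sketch assumes the encodec package is installed and only confirms the code tensor layout; the numbers follow from the 24 kHz model at 6 kbps (8 codebooks at 75 frames per second, i.e. the 320/24000 frame shift in AudioTokenConfig).

import torch

tokenizer = AudioTokenizer(device=torch.device("cpu"))
wav = torch.zeros(1, 1, tokenizer.sample_rate)  # (batch, channels, samples): 1 s of silence at 24 kHz

with torch.no_grad():
    encoded_frames = tokenizer.encode(wav)

codes = encoded_frames[0][0]           # (B, n_q, T)
print(codes.shape)                     # expected: torch.Size([1, 8, 75])
print(codes.dtype, codes.min() >= 0)   # integer codebook indices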
- prefix = args.prefix - if prefix and not prefix.endswith("_"): - prefix = f"{prefix}_" - with get_executor() as ex: - for partition, m in manifests.items(): - logging.info( - f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" - ) - try: - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - except Exception: - cut_set = m["cuts"] - - # Split cut_set if split > 1 - split = 1 - if args.split > 1: - cut_sets = cut_set.split(args.split) - split = args.split - else: - cut_sets = [cut_set] - - for idx, part in enumerate(cut_sets): - if args.audio_extractor: - if args.audio_extractor == "Encodec": - if split > 1: - storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}" - else: - storage_path = ( - f"{args.output_dir}/{args.prefix}_encodec_{partition}" - ) - else: - if split > 1: - storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}" - else: - storage_path = ( - f"{args.output_dir}/{args.prefix}_fbank_{partition}" - ) - - if args.prefix.lower() in [ - "ljspeech", - "aishell", - "baker", - "wenetspeech4tts", - ]: - part = part.resample(24000) - assert args.prefix.lower() in [ - "ljspeech", - "aishell", - "baker", - "wenetspeech4tts", - "libritts", - "libritts-r", - ] - with torch.no_grad(): - if ( - torch.cuda.is_available() - and args.audio_extractor == "Encodec" - ): - part = part.compute_and_store_features_batch( - extractor=audio_extractor, - storage_path=storage_path, - num_workers=num_jobs, - batch_duration=args.batch_duration, - collate=False, - overwrite=True, - storage_type=NumpyHdf5Writer, - ) - else: - part = part.compute_and_store_features( - extractor=audio_extractor, - storage_path=storage_path, - num_jobs=num_jobs if ex is None else 64, - executor=ex, - storage_type=NumpyHdf5Writer, - ) - - # TextTokenizer - if args.text_extractor: - for c in tqdm(part): - if args.prefix == "ljspeech": - text = c.supervisions[0].custom["normalized_text"] - text = text.replace(""", '"').replace(""", '"') - phonemes = tokenize_text(text_tokenizer, text=text) - elif args.prefix in [ - "aishell", - "aishell2", - "wenetspeech4tts", - "libritts", - "libritts-r", - ]: - phonemes = tokenize_text( - text_tokenizer, text=c.supervisions[0].text - ) - if c.supervisions[0].custom is None: - c.supervisions[0].custom = {} - c.supervisions[0].normalized_text = c.supervisions[0].text - else: - raise NotImplementedError(f"{args.prefix}") - unique_symbols.update(phonemes) - c.tokens = phonemes - assert c.supervisions[ - 0 - ].normalized_text, "normalized_text is None" - - # Save each part with an index if split > 1 - if split > 1: - cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}" - else: - cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}" - - part.to_file(f"{args.output_dir}/{cuts_filename}") - logging.info(f"Saved {cuts_filename}") - - if args.text_extractor: - unique_phonemes = SymbolTable() - for s in sorted(list(unique_symbols)): - unique_phonemes.add(s) - logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") - - unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" - unique_phonemes.to_file(unique_phonemes_file) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech4tts/TTS/local/compute_wer.sh b/egs/wenetspeech4tts/TTS/local/compute_wer.sh deleted file mode 100644 index 
283546383..000000000 --- a/egs/wenetspeech4tts/TTS/local/compute_wer.sh +++ /dev/null @@ -1,26 +0,0 @@ -wav_dir=$1 -wav_files=$(ls $wav_dir/*.wav) -# if wav_files is empty, then exit -if [ -z "$wav_files" ]; then - exit 1 -fi -label_file=$2 -model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 - -if [ ! -d $model_path ]; then - pip install sherpa-onnx - wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 - tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local -fi - -python3 local/offline-decode-files.py \ - --tokens=$model_path/tokens.txt \ - --paraformer=$model_path/model.int8.onnx \ - --num-threads=2 \ - --decoding-method=greedy_search \ - --debug=false \ - --sample-rate=24000 \ - --log-dir $wav_dir \ - --feature-dim=80 \ - --label $label_file \ - $wav_files diff --git a/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py deleted file mode 100755 index f967dfd2b..000000000 --- a/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# Copyright 2023 (authors: Feiteng Li) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in the manifests. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. -""" - -import argparse -from pathlib import Path - -from lhotse import load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/tokenized"), - help="Path to the tokenized manifests.", - ) - return parser.parse_args() - - -def main(): - args = get_args() - manifest_dir = args.manifest_dir or Path("data/tokenized") - for part in ["train", "dev", "test"]: - print(f"## {part}") - cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") - cuts.describe() - print("\n") - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/local/fbank.py b/egs/wenetspeech4tts/TTS/local/fbank.py deleted file mode 120000 index 3cfb7fe3f..000000000 --- a/egs/wenetspeech4tts/TTS/local/fbank.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/local/offline-decode-files.py b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py deleted file mode 100755 index fa6cbdb3e..000000000 --- a/egs/wenetspeech4tts/TTS/local/offline-decode-files.py +++ /dev/null @@ -1,495 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (c) 2023 by manyeyes -# Copyright (c) 2023 Xiaomi Corporation - -""" -This file demonstrates how to use sherpa-onnx Python API to transcribe -file(s) with a non-streaming model. 
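[Editor's note] display_manifest_statistics.py above is a thin wrapper; the same numbers are available interactively, which is handy when picking the duration filters mentioned in its docstring. A minimal sketch, assuming the recipe's default tokenized manifest location:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/tokenized/cuts_train.jsonl.gz")
cuts.describe()  # prints cut counts and duration statistics (mean, percentiles, total hours)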
- -(1) For paraformer - - ./python-api-examples/offline-decode-files.py \ - --tokens=/path/to/tokens.txt \ - --paraformer=/path/to/paraformer.onnx \ - --num-threads=2 \ - --decoding-method=greedy_search \ - --debug=false \ - --sample-rate=16000 \ - --feature-dim=80 \ - /path/to/0.wav \ - /path/to/1.wav - -(2) For transducer models from icefall - - ./python-api-examples/offline-decode-files.py \ - --tokens=/path/to/tokens.txt \ - --encoder=/path/to/encoder.onnx \ - --decoder=/path/to/decoder.onnx \ - --joiner=/path/to/joiner.onnx \ - --num-threads=2 \ - --decoding-method=greedy_search \ - --debug=false \ - --sample-rate=16000 \ - --feature-dim=80 \ - /path/to/0.wav \ - /path/to/1.wav - -(3) For CTC models from NeMo - -python3 ./python-api-examples/offline-decode-files.py \ - --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ - --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ - --num-threads=2 \ - --decoding-method=greedy_search \ - --debug=false \ - ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ - ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ - ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav - -(4) For Whisper models - -python3 ./python-api-examples/offline-decode-files.py \ - --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ - --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ - --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ - --whisper-task=transcribe \ - --num-threads=1 \ - ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ - ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ - ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav - -(5) For CTC models from WeNet - -python3 ./python-api-examples/offline-decode-files.py \ - --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ - --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ - ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ - ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ - ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav - -(6) For tdnn models of the yesno recipe from icefall - -python3 ./python-api-examples/offline-decode-files.py \ - --sample-rate=8000 \ - --feature-dim=23 \ - --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ - --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ - ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ - ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ - ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav - -Please refer to -https://k2-fsa.github.io/sherpa/onnx/index.html -to install sherpa-onnx and to download non-streaming pre-trained models -used in this file. -""" -import argparse -import time -import wave -from pathlib import Path -from typing import List, Tuple - -import numpy as np -import sherpa_onnx -import soundfile as sf - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tokens", - type=str, - help="Path to tokens.txt", - ) - - parser.add_argument( - "--hotwords-file", - type=str, - default="", - help=""" - The file containing hotwords, one words/phrases per line, like - HELLO WORLD - 你好世界 - """, - ) - - parser.add_argument( - "--hotwords-score", - type=float, - default=1.5, - help=""" - The hotword score of each token for biasing word/phrase. Used only if - --hotwords-file is given. 
- """, - ) - - parser.add_argument( - "--modeling-unit", - type=str, - default="", - help=""" - The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. - Used only when hotwords-file is given. - """, - ) - - parser.add_argument( - "--bpe-vocab", - type=str, - default="", - help=""" - The path to the bpe vocabulary, the bpe vocabulary is generated by - sentencepiece, you can also export the bpe vocabulary through a bpe model - by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given - and modeling-unit is bpe or cjkchar+bpe. - """, - ) - - parser.add_argument( - "--encoder", - default="", - type=str, - help="Path to the encoder model", - ) - - parser.add_argument( - "--decoder", - default="", - type=str, - help="Path to the decoder model", - ) - - parser.add_argument( - "--joiner", - default="", - type=str, - help="Path to the joiner model", - ) - - parser.add_argument( - "--paraformer", - default="", - type=str, - help="Path to the model.onnx from Paraformer", - ) - - parser.add_argument( - "--nemo-ctc", - default="", - type=str, - help="Path to the model.onnx from NeMo CTC", - ) - - parser.add_argument( - "--wenet-ctc", - default="", - type=str, - help="Path to the model.onnx from WeNet CTC", - ) - - parser.add_argument( - "--tdnn-model", - default="", - type=str, - help="Path to the model.onnx for the tdnn model of the yesno recipe", - ) - - parser.add_argument( - "--num-threads", - type=int, - default=1, - help="Number of threads for neural network computation", - ) - - parser.add_argument( - "--whisper-encoder", - default="", - type=str, - help="Path to whisper encoder model", - ) - - parser.add_argument( - "--whisper-decoder", - default="", - type=str, - help="Path to whisper decoder model", - ) - - parser.add_argument( - "--whisper-language", - default="", - type=str, - help="""It specifies the spoken language in the input audio file. - Example values: en, fr, de, zh, jp. - Available languages for multilingual models can be found at - https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 - If not specified, we infer the language from the input audio file. - """, - ) - - parser.add_argument( - "--whisper-task", - default="transcribe", - choices=["transcribe", "translate"], - type=str, - help="""For multilingual models, if you specify translate, the output - will be in English. - """, - ) - - parser.add_argument( - "--whisper-tail-paddings", - default=-1, - type=int, - help="""Number of tail padding frames. - We have removed the 30-second constraint from whisper, so you need to - choose the amount of tail padding frames by yourself. - Use -1 to use a default value for tail padding. - """, - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - parser.add_argument( - "--debug", - type=bool, - default=False, - help="True to show debug messages", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="""Sample rate of the feature extractor. Must match the one - expected by the model. 
Note: The input sound files can have a - different sample rate from this argument.""", - ) - - parser.add_argument( - "--feature-dim", - type=int, - default=80, - help="Feature dimension. Must match the one expected by the model", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to decode. Each file must be of WAVE" - "format with a single channel, and each sample has 16-bit, " - "i.e., int16_t. " - "The sample rate of the file can be arbitrary and does not need to " - "be 16 kHz", - ) - - parser.add_argument( - "--name", - type=str, - default="", - help="The directory containing the input sound files to decode", - ) - - parser.add_argument( - "--log-dir", - type=str, - default="", - help="The directory containing the input sound files to decode", - ) - - parser.add_argument( - "--label", - type=str, - default=None, - help="wav_base_name label", - ) - return parser.parse_args() - - -def assert_file_exists(filename: str): - assert Path(filename).is_file(), ( - f"{filename} does not exist!\n" - "Please refer to " - "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" - ) - - -def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: - """ - Args: - wave_filename: - Path to a wave file. It should be single channel and can be of type - 32-bit floating point PCM. Its sample rate does not need to be 24kHz. - - Returns: - Return a tuple containing: - - A 1-D array of dtype np.float32 containing the samples, - which are normalized to the range [-1, 1]. - - Sample rate of the wave file. - """ - - samples, sample_rate = sf.read(wave_filename, dtype="float32") - assert ( - samples.ndim == 1 - ), f"Expected single channel, but got {samples.ndim} channels." - - samples_float32 = samples.astype(np.float32) - - return samples_float32, sample_rate - - -def normalize_text_alimeeting(text: str) -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
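[Editor's note] A quick self-contained check of the read_wave() helper above: write a one-second test tone with soundfile and read it back. The filename is arbitrary; any mono PCM or float WAV works, since soundfile returns float32 samples in [-1, 1].

import numpy as np
import soundfile as sf

sr = 24000
tone = 0.5 * np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr).astype(np.float32)
sf.write("tone_440hz.wav", tone, sr)

samples, sample_rate = read_wave("tone_440hz.wav")
print(samples.dtype, samples.shape, sample_rate)  # -> float32 (24000,) 24000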
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - - -def main(): - args = get_args() - assert_file_exists(args.tokens) - assert args.num_threads > 0, args.num_threads - - assert len(args.nemo_ctc) == 0, args.nemo_ctc - assert len(args.wenet_ctc) == 0, args.wenet_ctc - assert len(args.whisper_encoder) == 0, args.whisper_encoder - assert len(args.whisper_decoder) == 0, args.whisper_decoder - assert len(args.tdnn_model) == 0, args.tdnn_model - - assert_file_exists(args.paraformer) - - recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( - paraformer=args.paraformer, - tokens=args.tokens, - num_threads=args.num_threads, - sample_rate=args.sample_rate, - feature_dim=args.feature_dim, - decoding_method=args.decoding_method, - debug=args.debug, - ) - - print("Started!") - start_time = time.time() - - streams, results = [], [] - total_duration = 0 - - for i, wave_filename in enumerate(args.sound_files): - assert_file_exists(wave_filename) - samples, sample_rate = read_wave(wave_filename) - duration = len(samples) / sample_rate - total_duration += duration - s = recognizer.create_stream() - s.accept_waveform(sample_rate, samples) - - streams.append(s) - if i % 10 == 0: - recognizer.decode_streams(streams) - results += [s.result.text for s in streams] - streams = [] - print(f"Processed {i} files") - # process the last batch - if streams: - recognizer.decode_streams(streams) - results += [s.result.text for s in streams] - end_time = time.time() - print("Done!") - - results_dict = {} - for wave_filename, result in zip(args.sound_files, results): - print(f"{wave_filename}\n{result}") - print("-" * 10) - wave_basename = Path(wave_filename).stem - results_dict[wave_basename] = result - - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - print(f"num_threads: {args.num_threads}") - print(f"decoding_method: {args.decoding_method}") - print(f"Wave duration: {total_duration:.3f} s") - print(f"Elapsed time: {elapsed_seconds:.3f} s") - print( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - if args.label: - from icefall.utils import store_transcripts, write_error_stats - - labels_dict = {} - with open(args.label, "r") as f: - for line in f: - # fields = line.strip().split(" ") - # fields = [item for item in fields if item] - # assert len(fields) == 4 - # prompt_text, prompt_audio, text, audio_path = fields - - fields = line.strip().split("|") - fields = [item for item in fields if item] - assert len(fields) == 4 - audio_path, prompt_text, prompt_audio, text = fields - labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) - - final_results = [] - for key, value in results_dict.items(): - 
final_results.append((key, labels_dict[key], value)) - - store_transcripts( - filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results - ) - with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: - write_error_stats(f, "test-set", final_results, enable_log=True) - - with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: - print(f.readline()) # WER - print(f.readline()) # Detailed errors - - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh deleted file mode 100755 index f1daa0e62..000000000 --- a/egs/wenetspeech4tts/TTS/prepare.sh +++ /dev/null @@ -1,165 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -stage=1 -stop_stage=4 - -dl_dir=$PWD/download - -dataset_parts="Premium" # Basic for all 7226 hours data, Premium for 945 hours subset. - -text_extractor="pypinyin_initials_finals" # default is espeak for English -audio_extractor="Encodec" # or Fbank -audio_feats_dir=data/tokenized - -. shared/parse_options.sh || exit 1 - - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "dl_dir: $dl_dir" - log "Stage 0: Download data" - huggingface-cli login - huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS - - # Extract the downloaded data: - for folder in Standard Premium Basic; do - for file in "$dl_dir/$folder"/*.tar.gz; do - tar -xzvf "$file" -C "$dl_dir/$folder" - done - done -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare wenetspeech4tts manifest" - # We assume that you have downloaded the wenetspeech4tts corpus - # to $dl_dir/wenetspeech4tts - mkdir -p data/manifests - if [ ! -e data/manifests/.wenetspeech4tts.done ]; then - lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}" - touch data/manifests/.wenetspeech4tts.done - fi -fi - - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Tokenize/Fbank wenetspeech4tts" - mkdir -p ${audio_feats_dir} - if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then - python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ - --text-extractor ${text_extractor} \ - --audio-extractor ${audio_extractor} \ - --batch-duration 2500 --prefix "wenetspeech4tts" \ - --src-dir "data/manifests" \ - --split 100 \ - --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" - cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir} - fi - touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Combine features" - if [ ! 
-f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then - pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz") - lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare wenetspeech4tts train/dev/test" - if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then - - lhotse subset --first 400 \ - ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ - ${audio_feats_dir}/cuts_dev.jsonl.gz - - lhotse subset --last 400 \ - ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ - ${audio_feats_dir}/cuts_test.jsonl.gz - - lhotse copy \ - ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ - ${audio_feats_dir}/cuts_train.jsonl.gz - - touch ${audio_feats_dir}/.wenetspeech4tts.train.done - fi - python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} -fi - -subset="Basic" -prefix="wenetspeech4tts" -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Generate fbank (used by ./f5-tts)" - mkdir -p data/fbank - if [ ! -e data/fbank/.${prefix}.done ]; then - ./local/compute_mel_feat.py --dataset-parts $subset --split 100 - touch data/fbank/.${prefix}.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)" - if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then - echo "Combining ${prefix} cuts" - pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz") - lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz - fi - if [ ! -e data/fbank/.${prefix}_split.done ]; then - echo "Splitting ${prefix} cuts into train, valid and test sets" - - lhotse subset --last 800 \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz - lhotse subset --first 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_valid.jsonl.gz - lhotse subset --last 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_test.jsonl.gz - - rm data/fbank/${prefix}_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 )) - lhotse subset --first $n \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_train.jsonl.gz - touch data/fbank/.${prefix}_split.done - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" - split_name=("valid" "test" "train") - for split in "${split_name[@]}"; do - echo "Processing $split" - wav_scp_file=wav_${split}.scp - output_dir="./cosy_v2_tokens_${split}" - oringinal_jsonl_file=data/fbank/${prefix}_cuts_${split}.jsonl.gz - mkdir -p $output_dir - zcat $oringinal_jsonl_file | jq -r '.recording.id + " " + .recording.sources[0].source' > $wav_scp_file - torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - `which s3tokenizer` --wav_scp $wav_scp_file \ - --device "cuda" \ - --output_dir $output_dir \ - --batch_size 32 \ - --num_workers 4 \ - --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz - - cat $output_dir/* > $output_dir/${prefix}_${split}_cosy_v2_tokens.json - python3 local/attach_speech_tokens.py --jsonl-prefix ${prefix}_cuts_${split} --tokens-path 
$output_dir/${prefix}_${split}_cosy_v2_tokens.json --manifest-dir data/fbank - done -fi diff --git a/egs/wenetspeech4tts/TTS/shared b/egs/wenetspeech4tts/TTS/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/wenetspeech4tts/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py deleted file mode 120000 index e70ee319a..000000000 --- a/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py deleted file mode 100644 index d98abb731..000000000 --- a/egs/wenetspeech4tts/TTS/valle/infer.py +++ /dev/null @@ -1,285 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 (authors: Feiteng Li) -# Copyright 2024 (authors: Yuekai Zhang) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script is used to synthesize speech from text prompts and audio prompts. -Usage example: - python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ - --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ - --text-prompts "KNOT one point one five miles per hour." \ - --audio-prompts ./prompts/8463_294825_000043_000000.wav \ - --text "To get up and running quickly just follow the steps below." 
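[Editor's note] Back in stage 7 of prepare.sh above, the jq expression turns each cut of the .jsonl.gz manifest into one "<recording_id> <wav path>" line of the wav.scp consumed by s3tokenizer. A Python equivalent of that one-liner, on a made-up cut record:

import json

# A minimal stand-in for one line of data/fbank/<prefix>_cuts_<split>.jsonl.gz
# (the id and source path are hypothetical).
cut_line = json.dumps({
    "id": "X0000000000_100638174_S00035",
    "recording": {
        "id": "X0000000000_100638174_S00035",
        "sources": [{"type": "file", "channels": [0],
                     "source": "download/Premium/wavs/X0000000000_100638174_S00035.wav"}],
    },
})

cut = json.loads(cut_line)
print(cut["recording"]["id"] + " " + cut["recording"]["sources"][0]["source"])
# -> X0000000000_100638174_S00035 download/Premium/wavs/X0000000000_100638174_S00035.wav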
- - top_p=1.0 - python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ - --top-k -1 --temperature 1.0 \ - --text ./aishell3.txt \ - --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ - --text-extractor pypinyin_initials_finals --top-p ${top_p} - -""" -import argparse -import logging -import os -from pathlib import Path - -os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - -import torch -import torchaudio -from compute_neural_codec_and_prepare_text_tokens import ( - AudioTokenizer, - TextTokenizer, - tokenize_text, -) -from encodec.utils import convert_audio -from k2 import symbol_table -from tokenizer import get_text_token_collater -from valle import VALLE - -from icefall.utils import AttributeDict, str2bool - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--text-prompts", - type=str, - default="", - help="Text prompts which are separated by |.", - ) - - parser.add_argument( - "--audio-prompts", - type=str, - default="", - help="Audio prompts which are separated by | and should be aligned with --text-prompts.", - ) - - parser.add_argument( - "--text", - type=str, - default="", - help="prompt text\t prompt audio\ttarget text\ttarget audio", - ) - - parser.add_argument( - "--text-extractor", - type=str, - default="espeak", - help="espeak or pypinyin or pypinyin_initials_finals", - ) - - parser.add_argument( - "--checkpoint", - type=str, - default="./valle/exp/checkpoint-100000.pt", - help="Path to the saved checkpoint.", - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("infer/demo"), - help="Path to the tokenized files.", - ) - - parser.add_argument( - "--top-k", - type=int, - default=-100, - help="Whether AR Decoder do top_k(if > 0) sampling.", - ) - - parser.add_argument( - "--top-p", - type=float, - default=1.0, - help="Whether AR Decoder do top_p(if > 0) sampling.", - ) - - parser.add_argument( - "--temperature", - type=float, - default=1.0, - help="The temperature of AR Decoder top_k sampling.", - ) - - parser.add_argument( - "--repetition-aware-sampling", - type=str2bool, - default=False, - help="Whether AR Decoder do valle-2 repetition-aware sampling. 
https://arxiv.org/pdf/2406.05370", - ) - - return parser.parse_args() - - -def load_model(checkpoint, device): - if not checkpoint: - return None - - checkpoint = torch.load(checkpoint, map_location=device) - - params = AttributeDict(checkpoint) - model = VALLE( - params.decoder_dim, - params.nhead, - params.num_decoder_layers, - norm_first=params.norm_first, - add_prenet=params.add_prenet, - prefix_mode=params.prefix_mode, - share_embedding=params.share_embedding, - nar_scale_factor=params.scale_factor, - prepend_bos=params.prepend_bos, - num_quantizers=params.num_quantizers, - ) - - missing_keys, unexpected_keys = model.load_state_dict( - checkpoint["model"], strict=True - ) - assert not missing_keys - model.to(device) - model.eval() - - return model, params.text_tokens - - -def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): - # Load and pre-process the audio waveform - wav, sr = torchaudio.load(audio_path) - wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) - wav = wav.unsqueeze(0) - - # Extract discrete codes from EnCodec - with torch.no_grad(): - encoded_frames = tokenizer.encode(wav) - return encoded_frames - - -@torch.no_grad() -def main(): - args = get_args() - text_tokenizer = TextTokenizer(backend=args.text_extractor) - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - model, text_tokens = load_model(args.checkpoint, device) - - text_collater = get_text_token_collater(text_tokens) - - audio_tokenizer = AudioTokenizer() - - Path(args.output_dir).mkdir(parents=True, exist_ok=True) - - text_prompts = " ".join(args.text_prompts.split("|")) - - audio_prompts = [] - if args.audio_prompts: - for n, audio_file in enumerate(args.audio_prompts.split("|")): - encoded_frames = tokenize_audio(audio_tokenizer, audio_file) - if False: - samples = audio_tokenizer.decode(encoded_frames) - torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000) - - audio_prompts.append(encoded_frames[0][0]) - - assert len(args.text_prompts.split("|")) == len(audio_prompts) - audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) - audio_prompts = audio_prompts.to(device) - - if os.path.isfile(args.text): # for demos - # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py - with open(args.text) as f: - for line in f: - fields = line.strip().split(" ") - fields = [item for item in fields if item] - assert len(fields) == 4 - prompt_text, prompt_audio, text, audio_path = fields - logging.info(f"synthesize text: {text}") - text_tokens, text_tokens_lens = text_collater( - [ - tokenize_text( - text_tokenizer, text=f"{prompt_text} {text}".strip() - ) - ] - ) - _, enroll_x_lens = text_collater( - [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())] - ) - - audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) - audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) - - # synthesis - encoded_frames = model.inference( - text_tokens.to(device), - text_tokens_lens.to(device), - audio_prompts, - enroll_x_lens=enroll_x_lens, - top_k=args.top_k, - temperature=args.temperature, - top_p=args.top_p, - ras=args.repetition_aware_sampling, - ) - - samples = audio_tokenizer.decode( - [(encoded_frames.transpose(2, 1), None)] - ) - # store - # save audio path into args.output_dir + audio_path - audio_path = f"{args.output_dir}/{audio_path}" - # mkdir -p - os.makedirs(os.path.dirname(audio_path), exist_ok=True) - torchaudio.save(audio_path, samples[0].cpu(), 24000) - return - - for 
n, text in enumerate(args.text.split("|")): - logging.info(f"synthesize text: {text}") - text_tokens, text_tokens_lens = text_collater( - [tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())] - ) - - # synthesis - enroll_x_lens = None - if text_prompts: - _, enroll_x_lens = text_collater( - [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())] - ) - encoded_frames = model.inference( - text_tokens.to(device), - text_tokens_lens.to(device), - audio_prompts, - enroll_x_lens=enroll_x_lens, - top_k=args.top_k, - temperature=args.temperature, - top_p=args.top_p, - ras=args.repetition_aware_sampling, - ) - - if audio_prompts != []: - samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)]) - # store - torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000) - else: # Transformer - pass - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech4tts/TTS/valle/optim.py b/egs/wenetspeech4tts/TTS/valle/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/wenetspeech4tts/TTS/valle/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/requirements.txt b/egs/wenetspeech4tts/TTS/valle/requirements.txt deleted file mode 100644 index 06958dbea..000000000 --- a/egs/wenetspeech4tts/TTS/valle/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -phonemizer==3.2.1 -git+https://github.com/facebookresearch/encodec.git \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/tokenizer.py b/egs/wenetspeech4tts/TTS/valle/tokenizer.py deleted file mode 100644 index db4f00396..000000000 --- a/egs/wenetspeech4tts/TTS/valle/tokenizer.py +++ /dev/null @@ -1,111 +0,0 @@ -from pathlib import Path -from typing import List, Tuple - -import numpy as np -import torch -from k2 import SymbolTable - - -class TextTokenCollater: - """Collate list of text tokens - - Map sentences to integers. Sentences are padded to equal length. - Beginning and end-of-sequence symbols can be added. - - Example: - >>> token_collater = TextTokenCollater(text_tokens) - >>> tokens_batch, tokens_lens = token_collater(text) - - Returns: - tokens_batch: IntTensor of shape (B, L) - B: batch dimension, number of input sentences - L: length of the longest sentence - tokens_lens: IntTensor of shape (B,) - Length of each sentence after adding and - but before padding. 
- """ - - def __init__( - self, - text_tokens: List[str], - add_eos: bool = True, - add_bos: bool = True, - pad_symbol: str = "", - bos_symbol: str = "", - eos_symbol: str = "", - ): - self.pad_symbol = pad_symbol - - self.add_eos = add_eos - self.add_bos = add_bos - - self.bos_symbol = bos_symbol - self.eos_symbol = eos_symbol - - unique_tokens = ( - [pad_symbol] - + ([bos_symbol] if add_bos else []) - + ([eos_symbol] if add_eos else []) - + sorted(text_tokens) - ) - - self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} - self.idx2token = [token for token in unique_tokens] - - def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: - seqs, seq_lens = [], [] - for tokens in tokens_list: - assert all([True if s in self.token2idx else False for s in tokens]) is True - seq = ( - ([self.bos_symbol] if self.add_bos else []) - + list(tokens) - + ([self.eos_symbol] if self.add_eos else []) - ) - seqs.append(seq) - seq_lens.append(len(seq)) - - max_len = max(seq_lens) - for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): - seq.extend([self.pad_symbol] * (max_len - seq_len)) - - tokens = torch.from_numpy( - np.array( - [[self.token2idx[token] for token in seq] for seq in seqs], - dtype=np.int64, - ) - ) - tokens_lens = torch.IntTensor(seq_lens) - - return tokens, tokens_lens - - def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: - tokens_seqs = [[p for p in text] for text in texts] - max_len = len(max(tokens_seqs, key=len)) - - seqs = [ - ([self.bos_symbol] if self.add_bos else []) - + list(seq) - + ([self.eos_symbol] if self.add_eos else []) - + [self.pad_symbol] * (max_len - len(seq)) - for seq in tokens_seqs - ] - - tokens_batch = torch.from_numpy( - np.array( - [[self.token2idx[token] for token in seq] for seq in seqs], - dtype=np.int64, - ) - ) - - tokens_lens = torch.IntTensor( - [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs] - ) - - return tokens_batch, tokens_lens - - -def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: - text_tokens_path = Path(text_tokens_file) - unique_tokens = SymbolTable.from_file(text_tokens_path) - collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True) - return collater diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py deleted file mode 100755 index e9ec548f3..000000000 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ /dev/null @@ -1,1243 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo) -# Copyright 2023 (authors: Feiteng Li) -# Copyright 2024 (authors: Yuekai Zhang) -# Copyright 2024 Tsinghua University (authors: Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
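[Editor's note] A small usage sketch of the TextTokenCollater above. In the recipe the token inventory comes from unique_text_tokens.k2symbols via get_text_token_collater(); here a toy character inventory stands in, so only the shapes and the bos/eos/padding behaviour are meaningful.

collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)
tokens_batch, tokens_lens = collater(["ab", "abc"])

print(tokens_batch.shape)  # (2, 5): longest sequence (3 tokens) plus bos and eos; the "ab" row is padded
print(tokens_lens)         # tensor([4, 5], dtype=torch.int32): true lengths including bos/eos, excluding padding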
-""" -Usage: -world_size=8 -exp_dir=exp/valle - -## Train AR model -python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ - --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ - --exp-dir ${exp_dir} --world-size ${world_size} - -## Train NAR model -# cd ${exp_dir} -# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 -# cd - -python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ - --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ - --share-embedding true --norm-first true --add-prenet false \ - --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ - --base-lr 0.03 --warmup-steps 200 --average-period 0 \ - --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ - --exp-dir ${exp_dir} --world-size ${world_size} -""" - -import argparse -import copy -import logging -import random -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse import CutSet -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from optim import Eden, ScaledAdam -from tokenizer import TextTokenCollater, get_text_token_collater -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import TtsDataModule -from valle import VALLE - -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--decoder-dim", - type=int, - default=1024, - help="Embedding dimension in the decoder model.", - ) - parser.add_argument( - "--nhead", - type=int, - default=16, - help="Number of attention heads in the Decoder layers.", - ) - parser.add_argument( - "--num-decoder-layers", - type=int, - default=12, - help="Number of Decoder layers.", - ) - parser.add_argument( - "--scale-factor", - type=float, - default=1.0, - help="Model scale factor which will be assigned different meanings in different models.", - ) - parser.add_argument( - "--norm-first", - type=str2bool, - default=True, - help="Pre or Post Normalization.", - ) - parser.add_argument( - "--add-prenet", - 
type=str2bool, - default=False, - help="Whether add PreNet after Inputs.", - ) - - parser.add_argument( - "--prefix-mode", - type=int, - default=0, - help="The mode for how to prefix VALL-E NAR Decoder, " - "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", - ) - parser.add_argument( - "--share-embedding", - type=str2bool, - default=True, - help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", - ) - parser.add_argument( - "--prepend-bos", - type=str2bool, - default=False, - help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", - ) - parser.add_argument( - "--num-quantizers", - type=int, - default=8, - help="Number of Audio/Semantic quantization layers.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="./valle/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--text-tokens", - type=str, - default="data/tokenized/unique_text_tokens.k2symbols", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--optimizer-name", - type=str, - default="ScaledAdam", - help="The optimizer.", - ) - parser.add_argument( - "--scheduler-name", - type=str, - default="Eden", - help="The scheduler.", - ) - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - parser.add_argument( - "--warmup-steps", - type=int, - default=200, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=10000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train %% save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. 
- """, - ) - parser.add_argument( - "--valid-interval", - type=int, - default=10000, - help="""Run validation if batch_idx %% valid_interval is 0.""", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=0, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--accumulate-grad-steps", - type=int, - default=1, - help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. - """, - ) - - parser.add_argument( - "--dtype", - type=str, - default="float32", - help="Training dtype: float32 bfloat16 float16.", - ) - - parser.add_argument( - "--filter-min-duration", - type=float, - default=0.0, - help="Keep only utterances with duration > this.", - ) - parser.add_argument( - "--filter-max-duration", - type=float, - default=20.0, - help="Keep only utterances with duration < this.", - ) - - parser.add_argument( - "--train-stage", - type=int, - default=0, - help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s) - """, - ) - - parser.add_argument( - "--visualize", - type=str2bool, - default=False, - help="visualize model results in eval step.", - ) - - parser.add_argument( - "--oom-check", - type=str2bool, - default=False, - help="perform OOM check on dataloader batches before starting training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 200, - "valid_interval": 10000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
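The branch above resolves which checkpoint to resume from: a positive `--start-batch` takes precedence and maps to `checkpoint-{start_batch}.pt`, otherwise `--start-epoch N` with N > 1 maps to `epoch-{N-1}.pt`, and a fresh run loads nothing. A hypothetical standalone helper illustrating the same precedence (not part of the recipe):

```python
# Sketch of the resume-path resolution: --start-batch wins over --start-epoch.
from pathlib import Path
from typing import Optional


def resolve_resume_checkpoint(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"  # batch-level checkpoint
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"   # end of previous epoch
    return None  # fresh training run


# e.g. resolve_resume_checkpoint(Path("exp/valle"), 0, 100) -> exp/valle/epoch-99.pt,
# matching the "--start-epoch 100 = 99 + 1" trick in the usage notes above.
```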
- - if isinstance(model, DDP): - raise ValueError("load_checkpoint before DDP") - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - saved_stage = saved_params.get("train_stage", 0) - if params.train_stage != saved_stage: - # switch training stage - if params.train_stage and saved_stage: # switch between 1 and 2 - params.start_epoch = 1 - params.start_batch = 0 - else: - # switch between 0 and 1/2 - assert params.num_epochs >= params.start_epoch - params.batch_idx_train = saved_params["batch_idx_train"] - - for key in ["optimizer", "grad_scaler", "sampler"]: - if key in saved_params: - saved_params.pop(key) - - # when base on stage 0, we keep scheduler - if saved_stage != 0: - for key in ["scheduler"]: - if key in saved_params: - saved_params.pop(key) - - best_train_filename = params.exp_dir / "best-train-loss.pt" - if best_train_filename.is_file(): - copyfile( - src=best_train_filename, - dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt", - ) - - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - if best_valid_filename.is_file(): - copyfile( - src=best_valid_filename, - dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt", - ) - else: - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device): - """Parse batch data""" - - features = batch["features"].to(device) - features_lens = batch["features_lens"].to(device) - if "tokens" not in batch: - raise NotImplementedError("Need to tokenize text") - # tokens = [] - # for c in batch["cuts"]: - # phonemes = tokenize_text( - # tokenizer, text=c.supervisions[0].text - # ) - # tokens.append(phonemes) - else: - tokens = batch["tokens"] - - text_tokens, text_tokens_lens = tokenizer(tokens) - text_tokens = text_tokens.to(device) - text_tokens_lens = text_tokens_lens.to(device) - - return features, features_lens, text_tokens, text_tokens_lens - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: TextTokenCollater, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - ( - audio_features, - audio_features_lens, - text_tokens, - text_tokens_lens, - ) = prepare_input(batch, tokenizer, device) - # at entry, TextTokens is (N, P) - assert text_tokens.ndim == 2 - assert audio_features.ndim == 3 - - with torch.set_grad_enabled(is_training): - predicts, loss, metrics = model( - x=text_tokens, - x_lens=text_tokens_lens, - y=audio_features, - y_lens=audio_features_lens, - train_stage=params.train_stage, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (audio_features_lens).sum().item() - info["utterances"] = text_tokens.size(0) - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - for metric in metrics: - info[metric] = metrics[metric].detach().cpu().item() - del metrics - - return predicts, loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: TextTokenCollater, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - predicts, loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - if world_size > 1: - tot_loss.reduce(loss.device) - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - if params.visualize: - output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") - output_dir.mkdir(parents=True, exist_ok=True) - if isinstance(model, DDP): - model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir) - else: - model.visualize(predicts, batch, tokenizer, output_dir=output_dir) - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: TextTokenCollater, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - rng: random.Random, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - rng: - Random for selecting. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - tot_loss = MetricsTracker() - iter_dl = iter(train_dl) - - dtype, enabled = torch.float32, False - if params.dtype in ["bfloat16", "bf16"]: - dtype, enabled = torch.bfloat16, True - elif params.dtype in ["float16", "fp16"]: - dtype, enabled = torch.float16, True - - batch_idx = 0 - while True: - try: - batch = next(iter_dl) - except StopIteration: - logging.info("Reaches end of dataloader.") - break - - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = len(batch["text"]) - - try: - with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): - _, loss, loss_info = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( - 1 / params.reset_interval - ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - scaler.scale(loss).backward() - if params.batch_idx_train >= params.accumulate_grad_steps: - if params.batch_idx_train % params.accumulate_grad_steps == 0: - if params.optimizer_name not in ["ScaledAdam", "Eve"]: - # Unscales the gradients of optimizer's assigned params in-place - scaler.unscale_(optimizer) - # Since the gradients of optimizer's assigned params are unscaled, clips as usual: - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - for k in range(params.accumulate_grad_steps): - if isinstance(scheduler, Eden): - scheduler.step_batch(params.batch_idx_train) - else: - scheduler.step() - - set_batch_count(model, params.batch_idx_train) - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.average_period > 0: - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - # Perform Operation in rank 0 - if rank == 0: - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - # Perform Operation in rank 0 - if rank == 0: - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = ( - scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 - ) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, train_loss[{loss_info}], " - f"tot_loss[{tot_loss}], " - f"batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - + ( - f", grad_scale: {cur_grad_scale}" - if params.dtype in ["float16", "fp16"] - else "" - ) - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train, - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.dtype in ["float16", "fp16"]: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if params.batch_idx_train % params.valid_interval == 0: - # Calculate validation loss in Rank 0 - model.eval() - logging.info("Computing validation loss") - with torch.cuda.amp.autocast(dtype=dtype): - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - model.train() - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def filter_short_and_long_utterances( - cuts: CutSet, min_duration: float, max_duration: float -) -> CutSet: - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 0.6 second and 20 seconds - if c.duration < min_duration or c.duration > max_duration: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - cuts = cuts.filter(remove_short_and_long_utt) - - return cuts - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - rng = random.Random(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - if params.train_stage: - tb_writer = SummaryWriter( - log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" - ) - else: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - torch.backends.cudnn.allow_tf32 = True - torch.backends.cuda.matmul.allow_tf32 = True - - logging.info(f"Device: {device}") - - tokenizer = get_text_token_collater(params.text_tokens) - logging.info(params) - - logging.info("About to create model") - - model = VALLE( - params.decoder_dim, - params.nhead, - params.num_decoder_layers, - norm_first=params.norm_first, - add_prenet=params.add_prenet, - prefix_mode=params.prefix_mode, - share_embedding=params.share_embedding, - nar_scale_factor=params.scale_factor, - prepend_bos=params.prepend_bos, - num_quantizers=params.num_quantizers, - ) - - with open(f"{params.exp_dir}/model.txt", "w") as f: - print(model) - print(model, file=f) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0 and params.average_period > 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - if params.train_stage: - _model = model.module if isinstance(model, DDP) else model - model_parameters = _model.stage_parameters(params.train_stage) - else: - model_parameters = model.parameters() - - if params.optimizer_name == "ScaledAdam": - optimizer = ScaledAdam( - model_parameters, - lr=params.base_lr, - clipping_scale=2.0, - ) - elif params.optimizer_name == "AdamW": - optimizer = torch.optim.AdamW( - model_parameters, - lr=params.base_lr, - betas=(0.9, 0.95), - weight_decay=1e-2, - eps=1e-8, - ) - elif params.optimizer_name == "Adam": - optimizer = torch.optim.Adam( - model_parameters, - lr=params.base_lr, - betas=(0.9, 0.95), - eps=1e-8, - ) - else: - raise NotImplementedError() - - scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) - optimizer.zero_grad() - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.inf_check: - register_inf_check_hooks(model) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - dataset = TtsDataModule(args) - 
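The `--optimizer-name` dispatch above chooses between icefall's `ScaledAdam` and standard `torch.optim` classes. A sketch of just the plain-PyTorch branches (the `ScaledAdam`/`Eden` pair lives in icefall and is omitted here):

```python
# Sketch of the --optimizer-name dispatch using only torch.optim
# (ScaledAdam/Eden are icefall-specific and not reproduced here).
import torch


def build_optimizer(name: str, params, base_lr: float) -> torch.optim.Optimizer:
    if name == "AdamW":
        return torch.optim.AdamW(
            params, lr=base_lr, betas=(0.9, 0.95), weight_decay=1e-2, eps=1e-8
        )
    if name == "Adam":
        return torch.optim.Adam(params, lr=base_lr, betas=(0.9, 0.95), eps=1e-8)
    raise NotImplementedError(f"Unsupported optimizer: {name}")


# e.g. build_optimizer("AdamW", torch.nn.Linear(4, 4).parameters(), base_lr=0.05)
```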
train_cuts = dataset.train_cuts() - valid_cuts = dataset.dev_cuts() - - train_cuts = filter_short_and_long_utterances( - train_cuts, params.filter_min_duration, params.filter_max_duration - ) - valid_cuts = filter_short_and_long_utterances( - valid_cuts, params.filter_min_duration, params.filter_max_duration - ) - - train_dl = dataset.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - valid_dl = dataset.dev_dataloaders(valid_cuts) - - if params.oom_check: - scan_pessimistic_batches_for_oom( - model=model, - tokenizer=tokenizer, - train_dl=train_dl, - optimizer=optimizer, - params=params, - ) - - scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - if isinstance(scheduler, Eden): - scheduler.step_epoch(epoch - 1) - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - rng=rng, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - tokenizer: TextTokenCollater, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - - dtype = torch.float32 - if params.dtype in ["bfloat16", "bf16"]: - dtype = torch.bfloat16 - elif params.dtype in ["float16", "fp16"]: - dtype = torch.float16 - - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(dtype=dtype): - _, loss, _ = compute_loss( - params=params, - model=model, - tokenizer=tokenizer, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - TtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py deleted file mode 100644 index 8e34d06dc..000000000 --- a/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin,) -# Copyright 2023 (authors: Feiteng Li) -# Copyright 2024 (Author: Yuekai Zhang) -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.features.io import KaldiReader -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class TtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in TTS - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/tokenized"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--speaker-embeds", - type=Path, - default=Path("exp/xvector_nnet_1a/"), - help="Path to directory with speaker embeddings.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=4, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=False, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - group.add_argument( - "--dataset", - type=str, - default="libritts", - help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", - ) - - parser.add_argument( - "--sampling-rate", - type=int, - default=24000, - help="""Audio sampling rate.""", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. 
- """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - raise NotImplementedError - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - raise NotImplementedError - else: - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - dev_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create valid dataloader") - dev_dl = DataLoader( - validate, - sampler=dev_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return dev_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - raise NotImplementedError - else: - test = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=False, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def 
dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" - ) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py deleted file mode 100644 index 8f9b8fc3d..000000000 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ /dev/null @@ -1,1731 +0,0 @@ -# Copyright 2023 (authors: Feiteng Li) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import numbers -import random -from functools import partial -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from tokenizer import TextTokenCollater -from torch import Tensor -from torch.nn import Linear, Module -from torch.nn import functional as F -from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ -from torch.nn.modules.linear import NonDynamicallyQuantizableLinear -from torch.nn.parameter import Parameter -from torchmetrics.classification import MulticlassAccuracy - -from icefall.utils import make_pad_mask - -NUM_TEXT_TOKENS = 5000 -NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins - - -class PromptedFeatures: - def __init__(self, prompts, features): - self.prompts = prompts - self.features = features - - def to(self, device): - return PromptedFeatures(self.prompts.to(device), self.features.to(device)) - - def sum(self): - return self.features.sum() - - @property - def ndim(self): - return self.features.ndim - - @property - def data(self): - return (self.prompts, self.features) - - -class TokenEmbedding(nn.Module): - def __init__( - self, - dim_model: int, - vocab_size: int, - dropout: float = 0.0, - ): - super().__init__() - - self.vocab_size = vocab_size - self.dim_model = dim_model - - self.dropout = torch.nn.Dropout(p=dropout) - self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) - - @property - def weight(self) -> torch.Tensor: - return self.word_embeddings.weight - - def embedding(self, index: int) -> torch.Tensor: - return self.word_embeddings.weight[index : index + 1] - - def forward(self, x: torch.Tensor): - X = self.word_embeddings(x) - X = self.dropout(X) - - return X - - -class SinePositionalEmbedding(nn.Module): - def __init__( - self, - dim_model: int, - dropout: float = 0.0, - scale: bool = False, - alpha: bool = False, - ): - super().__init__() - self.dim_model = dim_model - self.x_scale = math.sqrt(dim_model) if scale else 1.0 - self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) - self.dropout = 
torch.nn.Dropout(p=dropout) - - self.reverse = False - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, 4000)) - - def extend_pe(self, x): - """Reset the positional encodings.""" - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.dim_model) - if self.reverse: - position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32 - ).unsqueeze(1) - else: - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.dim_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.dim_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.pe = pe.to(device=x.device, dtype=x.dtype).detach() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - self.extend_pe(x) - output = x.unsqueeze(-1) if x.ndim == 2 else x - output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] - return self.dropout(output) - - -class Transpose(nn.Identity): - """(N, T, D) -> (N, D, T)""" - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return input.transpose(1, 2) - - -_shape_t = Union[int, List[int], torch.Size] - - -class MultiheadAttention(Module): - r"""Allows the model to jointly attend to information - from different representation subspaces as described in the paper: - `Attention Is All You Need `_. - - Multi-Head Attention is defined as: - - .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - - where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. - - ``forward()`` will use a special optimized implementation if all of the following - conditions are met: - - - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This - restriction will be loosened in the future.) - - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` - - training is disabled (using ``.eval()``) - - dropout is 0 - - ``add_bias_kv`` is ``False`` - - ``add_zero_attn`` is ``False`` - - ``batch_first`` is ``True`` and the input is batched - - ``kdim`` and ``vdim`` are equal to ``embed_dim`` - - at most one of ``key_padding_mask`` or ``attn_mask`` is passed - - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` - nor ``attn_mask`` is passed - - If the optimized implementation is in use, a - `NestedTensor `_ can be passed for - ``query``/``key``/``value`` to represent padding more efficiently than using a - padding mask. In this case, a `NestedTensor `_ - will be returned, and an additional speedup proportional to the fraction of the input - that is padding can be expected. - - Args: - embed_dim: Total dimension of the model. - num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split - across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). - dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). - bias: If specified, adds bias to input / output projection layers. Default: ``True``. - add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. - add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. - Default: ``False``. - kdim: Total number of features for keys. 
Default: ``None`` (uses ``kdim=embed_dim``). - vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` (seq, batch, feature). - - Examples:: - - >>> # xdoctest: +SKIP - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) - - """ - __constants__ = ["batch_first"] - bias_k: Optional[torch.Tensor] - bias_v: Optional[torch.Tensor] - - def __init__( - self, - embed_dim, - num_heads, - dropout=0.0, - bias=True, - add_bias_kv=False, - add_zero_attn=False, - kdim=None, - vdim=None, - batch_first=False, - linear1_cls=Linear, - linear2_cls=Linear, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(MultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.kdim = kdim if kdim is not None else embed_dim - self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim - - self.num_heads = num_heads - self.dropout = dropout - self.batch_first = batch_first - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - if add_bias_kv: - self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - else: - self.bias_k = self.bias_v = None - - if linear1_cls == Linear: - if not self._qkv_same_embed_dim: - self.q_proj_weight = Parameter( - torch.empty((embed_dim, embed_dim), **factory_kwargs) - ) - self.k_proj_weight = Parameter( - torch.empty((embed_dim, self.kdim), **factory_kwargs) - ) - self.v_proj_weight = Parameter( - torch.empty((embed_dim, self.vdim), **factory_kwargs) - ) - self.register_parameter("in_proj_weight", None) - else: - self.in_proj_weight = Parameter( - torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) - ) - self.register_parameter("q_proj_weight", None) - self.register_parameter("k_proj_weight", None) - self.register_parameter("v_proj_weight", None) - - if bias: - self.in_proj_bias = Parameter( - torch.empty(3 * embed_dim, **factory_kwargs) - ) - else: - self.register_parameter("in_proj_bias", None) - self.out_proj = NonDynamicallyQuantizableLinear( - embed_dim, embed_dim, bias=bias, **factory_kwargs - ) - - self._reset_parameters() - else: - if not self._qkv_same_embed_dim: - raise NotImplementedError - else: - self.in_proj_linear = linear1_cls( - embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs - ) - self.in_proj_weight = self.in_proj_linear.weight - - self.register_parameter("q_proj_weight", None) - self.register_parameter("k_proj_weight", None) - self.register_parameter("v_proj_weight", None) - - if bias: - self.in_proj_bias = self.in_proj_linear.bias - else: - self.register_parameter("in_proj_bias", None) - - self.out_proj = linear2_cls( - embed_dim, embed_dim, bias=bias, **factory_kwargs - ) - - if self.bias_k is not None: - xavier_normal_(self.bias_k) - if self.bias_v is not None: - xavier_normal_(self.bias_v) - - self.add_zero_attn = add_zero_attn - - def _reset_parameters(self): - if self._qkv_same_embed_dim: - xavier_uniform_(self.in_proj_weight) - else: - xavier_uniform_(self.q_proj_weight) - xavier_uniform_(self.k_proj_weight) - xavier_uniform_(self.v_proj_weight) - - if self.in_proj_bias is not None: - 
constant_(self.in_proj_bias, 0.0) - constant_(self.out_proj.bias, 0.0) - - if self.bias_k is not None: - xavier_normal_(self.bias_k) - if self.bias_v is not None: - xavier_normal_(self.bias_v) - - def __setstate__(self, state): - # Support loading old MultiheadAttention checkpoints generated by v1.1.0 - if "_qkv_same_embed_dim" not in state: - state["_qkv_same_embed_dim"] = True - - super(MultiheadAttention, self).__setstate__(state) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - average_attn_weights: bool = True, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and byte masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. 
average weights across heads) - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - .. note:: - `batch_first` argument is ignored for unbatched inputs. - """ - is_batched = query.dim() == 3 - if key_padding_mask is not None: - _kpm_dtype = key_padding_mask.dtype - if _kpm_dtype != torch.bool and not torch.is_floating_point( - key_padding_mask - ): - raise AssertionError( - "only bool and floating types of key_padding_mask are supported" - ) - why_not_fast_path = "" - if not is_batched: - why_not_fast_path = ( - f"input not batched; expected query.dim() of 3 but got {query.dim()}" - ) - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif ( - self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype - ): - # this case will fail anyway, but at least they'll get a useful error message. - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif attn_mask is not None: - why_not_fast_path = "attn_mask was not None" - elif query.is_nested and key_padding_mask is not None: - why_not_fast_path = ( - "key_padding_mask is not supported with NestedTensor input" - ) - elif self.num_heads % 2 == 1: - why_not_fast_path = "num_heads is odd" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - - if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. 
- if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all( - [ - (x is None or x.is_cuda or "cpu" in str(x.device)) - for x in tensor_args - ] - ): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any( - [x is not None and x.requires_grad for x in tensor_args] - ): - why_not_fast_path = ( - "grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad" - ) - if not why_not_fast_path: - return torch._native_multi_head_attention( - query, - key, - value, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - key_padding_mask if key_padding_mask is not None else attn_mask, - need_weights, - average_attn_weights, - 1 - if key_padding_mask is not None - else 0 - if attn_mask is not None - else None, - ) - - any_nested = query.is_nested or key.is_nested or value.is_nested - assert not any_nested, ( - "MultiheadAttention does not support NestedTensor outside of its fast path. " - + f"The fast path was not hit because {why_not_fast_path}" - ) - - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - - if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, - k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, - average_attn_weights=average_attn_weights, - ) - else: - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - average_attn_weights=average_attn_weights, - ) - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights - else: - return attn_output, attn_output_weights - - -class LayerNorm(nn.Module): - __constants__ = ["normalized_shape", "eps", "elementwise_affine"] - normalized_shape: Tuple[int, ...] 
- eps: float - elementwise_affine: bool - - def __init__( - self, - normalized_shape: _shape_t, - eps: float = 1e-5, - elementwise_affine: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(LayerNorm, self).__init__() - if isinstance(normalized_shape, numbers.Integral): - # mypy error: incompatible types in assignment - normalized_shape = (normalized_shape,) # type: ignore[assignment] - self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - self.bias = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - else: - self.register_parameter("weight", None) - self.register_parameter("bias", None) - - self.reset_parameters() - - def reset_parameters(self) -> None: - if self.elementwise_affine: - nn.init.ones_(self.weight) - nn.init.zeros_(self.bias) - - def forward(self, input: Tensor, embedding: Any = None) -> Tensor: - if isinstance(input, tuple): - input, embedding = input - return ( - F.layer_norm( - input, - self.normalized_shape, - self.weight, - self.bias, - self.eps, - ), - embedding, - ) - - assert embedding is None - return F.layer_norm( - input, self.normalized_shape, self.weight, self.bias, self.eps - ) - - def extra_repr(self) -> str: - return ( - "{normalized_shape}, eps={eps}, " - "elementwise_affine={elementwise_affine}".format(**self.__dict__) - ) - - -class AdaptiveLayerNorm(nn.Module): - r"""Adaptive Layer Normalization""" - - def __init__(self, d_model, norm) -> None: - super(AdaptiveLayerNorm, self).__init__() - self.project_layer = nn.Linear(d_model, 2 * d_model) - self.norm = norm - self.d_model = d_model - self.eps = self.norm.eps - - def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: - if isinstance(input, tuple): - input, embedding = input - weight, bias = torch.split( - self.project_layer(embedding), - split_size_or_sections=self.d_model, - dim=-1, - ) - return (weight * self.norm(input) + bias, embedding) - - weight, bias = torch.split( - self.project_layer(embedding), - split_size_or_sections=self.d_model, - dim=-1, - ) - return weight * self.norm(input) + bias - - -class TransformerEncoderLayer(nn.Module): - __constants__ = ["batch_first", "norm_first"] - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - batch_first: bool = False, - norm_first: bool = False, - device=None, - dtype=None, - linear1_self_attention_cls: nn.Module = nn.Linear, - linear2_self_attention_cls: nn.Module = nn.Linear, - linear1_feedforward_cls: nn.Module = nn.Linear, - linear2_feedforward_cls: nn.Module = nn.Linear, - layer_norm_cls: nn.Module = LayerNorm, - layer_norm_eps: float = 1e-5, - adaptive_layer_norm=False, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(TransformerEncoderLayer, self).__init__() - self.self_attn = MultiheadAttention( - d_model, - nhead, - dropout=dropout, - batch_first=batch_first, - linear1_cls=linear1_self_attention_cls, - linear2_cls=linear2_self_attention_cls, - **factory_kwargs, - ) - - # Implementation of Feedforward model - self.linear1 = linear1_feedforward_cls( - d_model, dim_feedforward, **factory_kwargs - ) - self.dropout = nn.Dropout(dropout) - self.linear2 = linear2_feedforward_cls( - dim_feedforward, 
d_model, **factory_kwargs - ) - - self.norm_first = norm_first - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - # Legacy string support for activation function. - if isinstance(activation, str): - activation = _get_activation_fn(activation) - elif isinstance(activation, partial): - activation = activation(d_model) - # elif activation == BalancedDoubleSwish: - # activation = BalancedDoubleSwish(d_model) - - # # We can't test self.activation in forward() in TorchScript, - # # so stash some information about it instead. - # if activation is F.relu or isinstance(activation, torch.nn.ReLU): - # self.activation_relu_or_gelu = 1 - # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): - # self.activation_relu_or_gelu = 2 - # else: - # self.activation_relu_or_gelu = 0 - self.activation = activation - - norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) - # if layer_norm_cls == IdentityNorm: - # norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - # else: - if True: - norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) - - if adaptive_layer_norm: - self.norm1 = AdaptiveLayerNorm(d_model, norm1) - self.norm2 = AdaptiveLayerNorm(d_model, norm2) - else: - self.norm1 = norm1 - self.norm2 = norm2 - - def __setstate__(self, state): - super(TransformerEncoderLayer, self).__setstate__(state) - if not hasattr(self, "activation"): - self.activation = F.relu - - def forward( - self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - x, stage_embedding = src, None - is_src_tuple = False - if isinstance(src, tuple): - x, stage_embedding = src - is_src_tuple = True - - if src_key_padding_mask is not None: - _skpm_dtype = src_key_padding_mask.dtype - if _skpm_dtype != torch.bool and not torch.is_floating_point( - src_key_padding_mask - ): - raise AssertionError( - "only bool and floating types of key_padding_mask are supported" - ) - - if self.norm_first: - x = x + self._sa_block( - self.norm1(x, stage_embedding), - src_mask, - src_key_padding_mask, - ) - x = x + self._ff_block(self.norm2(x, stage_embedding)) - else: - x = self.norm1( - x + self._sa_block(x, src_mask, src_key_padding_mask), - stage_embedding, - ) - x = self.norm2(x + self._ff_block(x), stage_embedding) - - if is_src_tuple: - return (x, stage_embedding) - return x - - # self-attention block - def _sa_block( - self, - x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - ) -> Tensor: - x = self.self_attn( - x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False, - )[0] - return self.dropout1(x) - - # feed forward block - def _ff_block(self, x: Tensor) -> Tensor: - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - return self.dropout2(x) - - -class TransformerEncoder(nn.Module): - r"""TransformerEncoder is a stack of N encoder layers. Users can build the - BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. - - Args: - encoder_layer: an instance of the TransformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). 
- norm: the layer normalization component (optional). - enable_nested_tensor: if True, input will automatically convert to nested tensor - (and convert back on output). This will improve the overall performance of - TransformerEncoder when padding rate is high. Default: ``True`` (enabled). - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - __constants__ = ["norm"] - - def __init__(self, encoder_layer, num_layers, norm=None): - super(TransformerEncoder, self).__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - return_layer_states: bool = False, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - return_layer_states: return layers' state (optional). - - Shape: - see the docs in Transformer class. - """ - if return_layer_states: - layer_states = [] # layers' output - output = src - for mod in self.layers: - output = mod( - output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - layer_states.append(output[0]) - - if self.norm is not None: - output = self.norm(output) - - return layer_states, output - - output = src - for mod in self.layers: - output = mod( - output, src_mask=mask, src_key_padding_mask=src_key_padding_mask - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - -def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - -class VALLE(nn.Module): - """It implements https://arxiv.org/abs/2301.02111 - "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" - """ - - def __init__( - self, - d_model: int, - nhead: int, - num_layers: int, - norm_first: bool = True, - add_prenet: bool = False, - decoder_cls=TransformerEncoder, - decoder_layer_cls=TransformerEncoderLayer, - prefix_mode: int = 0, - share_embedding: bool = True, - nar_scale_factor: float = 1.0, - prepend_bos: bool = False, - num_quantizers: int = 8, - **kwargs, - ): - """ - Args: - d_model: - The number of expected features in the input (required). - nhead: - The number of heads in the multiheadattention models (required). - num_layers: - The number of sub-decoder-layers in the decoder (required). 
- """ - super().__init__() - nar_d_model = int(d_model * nar_scale_factor) - - self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x - self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) - - # ID NUM_AUDIO_TOKENS -> PAD - # ID NUM_AUDIO_TOKENS + 1 -> BOS - self.ar_audio_prepend_bos = prepend_bos - self.ar_audio_embedding = TokenEmbedding( - d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) - ) - - # PreNet - if add_prenet: - self.ar_text_prenet = nn.Sequential( - Transpose(), - nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(d_model), - nn.ReLU(), - nn.Dropout(0.5), - nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(d_model), - nn.ReLU(), - nn.Dropout(0.5), - nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(d_model), - nn.ReLU(), - nn.Dropout(0.5), - Transpose(), - nn.Linear(d_model, d_model), - ) - - self.ar_audio_prenet = nn.Sequential( - nn.Linear(d_model, 256), - nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(256, 256), - nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(256, d_model), - ) - else: - self.ar_text_prenet = nn.Identity() - self.ar_audio_prenet = nn.Identity() - - self.ar_text_position = SinePositionalEmbedding( - d_model, - dropout=0.1, - scale=False, - alpha=True, - ) - self.ar_audio_position = SinePositionalEmbedding( - d_model, - dropout=0.1, - scale=False, - alpha=True, - ) - - self.ar_decoder = decoder_cls( - decoder_layer_cls( - d_model, - nhead, - dim_feedforward=d_model * 4, - dropout=0.1, - batch_first=True, - norm_first=norm_first, - ), - num_layers=num_layers, - norm=LayerNorm(d_model) if norm_first else None, - ) - self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False) - - self.ar_accuracy_metric = MulticlassAccuracy( - NUM_AUDIO_TOKENS + 1, - top_k=10, - average="micro", - multidim_average="global", - ignore_index=NUM_AUDIO_TOKENS, - ) - - self.rng = random.Random(0) - self.num_heads = nhead - self.prefix_mode = prefix_mode - self.num_quantizers = num_quantizers - - assert num_quantizers >= 1 - if num_quantizers > 1: - self.nar_audio_embeddings = nn.ModuleList( - [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] - + [ - TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) - for i in range(num_quantizers - 1) - ] - ) # W_a - - # PreNet - if add_prenet: - self.nar_text_prenet = nn.Sequential( - Transpose(), - nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(nar_d_model), - nn.ReLU(), - nn.Dropout(0.5), - nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(nar_d_model), - nn.ReLU(), - nn.Dropout(0.5), - nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), - nn.BatchNorm1d(nar_d_model), - nn.ReLU(), - nn.Dropout(0.5), - Transpose(), - nn.Linear(nar_d_model, nar_d_model), - ) - self.nar_audio_prenet = nn.Sequential( - nn.Linear(nar_d_model, 256), - nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(256, 256), - nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(256, nar_d_model), - ) - else: - self.nar_text_prenet = nn.Identity() - self.nar_audio_prenet = nn.Identity() - - self.nar_text_position = SinePositionalEmbedding( - nar_d_model, - dropout=0.0, - scale=False, - alpha=False, - ) - self.nar_audio_position = SinePositionalEmbedding( - nar_d_model, - dropout=0.1, - scale=False, - alpha=False, - ) - - self.nar_decoder = decoder_cls( - decoder_layer_cls( - nar_d_model, - int(nhead * nar_scale_factor), - dim_feedforward=nar_d_model * 4, - dropout=0.1, - batch_first=True, - 
norm_first=norm_first, - adaptive_layer_norm=True, - ), - num_layers=int(num_layers * nar_scale_factor), - norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model)) - if norm_first - else None, - ) - self.nar_predict_layers = nn.ModuleList( - [ - nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) - for i in range(num_quantizers - 1) - ] - ) - self.nar_stage_embeddings = nn.ModuleList( - [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)] - ) - - if share_embedding: - # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa - # NOTE(Feiteng): In the experiment, this undermines accuracy - # self.ar_predict_layer.weight = self.ar_audio_embedding.weight - - # We also share the parameters of the acoustic embedding layer and the output prediction layer, - # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. - for j in range(0, num_quantizers - 2): - self.nar_predict_layers[j].weight = self.nar_audio_embeddings[ - j + 2 - ].weight - - self.nar_accuracy_metric = MulticlassAccuracy( - NUM_AUDIO_TOKENS + 1, - top_k=10, - average="micro", - multidim_average="global", - ignore_index=NUM_AUDIO_TOKENS, - ) - - def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: - assert stage > 0 - if stage == 1: - for name, param in self.named_parameters(): - if name.startswith("ar_"): - print(f" AR parameter: {name}") - yield param - - if stage == 2: - for name, param in self.named_parameters(): - if name.startswith("nar_"): - print(f"NAR parameter: {name}") - yield param - - def stage_named_parameters( - self, stage: int = 1 - ) -> Iterator[Tuple[str, nn.Parameter]]: - assert stage > 0 - if stage == 1: - for pair in self.named_parameters(): - if pair[0].startswith("ar_"): - yield pair - - if stage == 2: - for pair in self.named_parameters(): - if pair[0].startswith("nar_"): - yield pair - - def pad_y_eos(self, y, y_mask_int, eos_id): - targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( - y_mask_int, (0, 1), value=1 - ) - # inputs, targets - if self.ar_audio_prepend_bos: - return ( - F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), - targets, - ) - - return targets[:, :-1], targets[:, 1:] - - def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): - # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds - # from the same utterance. - # We implement this differently. 
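(Editor's aside, not part of the diff: the 3-second acoustic prompt mentioned in the comment above is turned into a frame budget in the code that follows — EnCodec runs at 24000 Hz with a hop of 320 samples, i.e. 75 frames per second, so 3 s corresponds to the 225-frame cap used below. A small sanity check of that arithmetic; the helper name is ours, not from the recipe:)

```python
def prompt_frame_budget(sample_rate: int = 24000, hop_length: int = 320, seconds: float = 3.0) -> int:
    """Number of codec frames produced for `seconds` of audio."""
    frames_per_second = sample_rate // hop_length  # 75 for EnCodec at 24 kHz
    return int(frames_per_second * seconds)


assert prompt_frame_budget() == 225  # matches `prefix_len = min(prefix_len, 225)` below
```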
- if self.prefix_mode == 0: - # no prefix - prefix_len = 0 - y_emb = self.nar_audio_embeddings[0](y) - for j in range(1, nar_stage): - # Formula (4) (5) - y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) - elif self.prefix_mode == 1: - # prefix at begining - int_low = (0.25 * y_lens.min()).type(torch.int64).item() - prefix_len = torch.randint(int_low, int_low * 2, size=()).item() - prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames - - y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) - y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) - for j in range(1, self.num_quantizers): - y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) - if j < nar_stage: - y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) - y_emb = torch.concat([y_prompts, y_emb], axis=1) - elif self.prefix_mode in [2, 4]: - if self.prefix_mode == 2: - # random prefix - prefix_len = min(225, int(0.25 * y_lens.min().item())) - - y_prompts_codes = [] - for b in range(codes.shape[0]): - start = self.rng.randint(0, y_lens[b].item() - prefix_len) - y_prompts_codes.append( - torch.clone(codes[b, start : start + prefix_len]) - ) - codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS - y_prompts_codes = torch.stack(y_prompts_codes, dim=0) - else: - prefix_len = y_prompts_codes.shape[1] - - y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) - y_emb = self.nar_audio_embeddings[0](y) - for j in range(1, self.num_quantizers): - y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) - if j < nar_stage: - y_emb += self.nar_audio_embeddings[j](codes[..., j]) - y_emb = torch.concat([y_prompts, y_emb], axis=1) - else: - raise ValueError - - return y_emb, prefix_len - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: Union[torch.Tensor, PromptedFeatures], - y_lens: Union[torch.Tensor, PromptedFeatures], - reduction: str = "sum", - train_stage: int = 0, - **kwargs, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """ - Args: - x: - A 2-D tensor of shape (N, S). - x_lens: - A 1-D tensor of shape (N,). It contains the number of tokens in `x` - before padding. - y: - A 3-D tensor of shape (N, T, 8). - y_lens: - A 1-D tensor of shape (N,). It contains the number of tokens in `x` - before padding. - train_stage: - 0: AR & NAR modules, 1: AR modules, 2: NAR modules - Returns: - Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 
- """ - assert x.ndim == 2, x.shape - assert x_lens.ndim == 1, x_lens.shape - - y_prompts_codes = None - if isinstance(y, PromptedFeatures): - y_prompts_codes, y = y.data - prompts_len, y_lens = y_lens.data - assert prompts_len.min() == prompts_len.max() - assert self.prefix_mode == 4 - y_prompts_codes = y_prompts_codes.type(torch.int64) - - assert y.ndim == 3, y.shape - assert y_lens.ndim == 1, y_lens.shape - - # NOTE: x has been padded in TextTokenCollater - x_mask = make_pad_mask(x_lens).to(x.device) - y_mask = make_pad_mask(y_lens).to(y.device) - y_mask_int = y_mask.type(torch.int64) - - text = x - codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) - - y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS) - - x_len = x_lens.max() - - metrics = {} - total_loss = 0.0 - - xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) - if self.ar_audio_prepend_bos: - ar_xy_padding_mask = torch.concat( - [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 - ) - else: - ar_xy_padding_mask = xy_padding_mask - # AR Decoder - if train_stage in [0, 1]: - x = self.ar_text_embedding(text) - x = self.ar_text_prenet(x) - x = self.ar_text_position(x) - - y_len = y_lens.max() + int(self.ar_audio_prepend_bos) - - x_attn_mask = F.pad( - torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), - (0, y_len), - value=True, - ) - y_attn_mask = F.pad( - torch.triu( - torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), - diagonal=1, - ), - (x_len, 0), - value=False, - ) - xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) - - # merge key padding and attention masks - bsz, src_len = x.shape[0], x_len + y_len - _xy_padding_mask = ( - ar_xy_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_heads, -1, -1) - .reshape(bsz * self.num_heads, 1, src_len) - ) - xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) - - new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) - new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) - xy_attn_mask = new_attn_mask - - y_emb = self.ar_audio_embedding(y) - y_emb = self.ar_audio_prenet(y_emb) - y_pos = self.ar_audio_position(y_emb) - - xy_pos = torch.concat([x, y_pos], dim=1) - - xy_dec, _ = self.ar_decoder( - (xy_pos, None), - mask=xy_attn_mask, - # src_key_padding_mask=xy_padding_mask, - # is_causal=True, - ) - logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) - # loss - total_loss = F.cross_entropy(logits, targets, reduction=reduction) - - metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( - logits.detach(), targets - ).item() * y_lens.sum().type(torch.float32) - - if self.num_quantizers == 1: - return ((x, codes), total_loss, metrics) - - # Non-AR Decoders - if self.ar_audio_prepend_bos: - y = y[:, 1:] - if train_stage in [0, 2]: - num_nar_layers = self.num_quantizers - 1 - nar_stage = self.rng.choices( - [_k for _k in range(1, self.num_quantizers)], - weights=[1.0 / num_nar_layers] * num_nar_layers, - k=1, - )[0] - - x = self.nar_text_embedding(text) - x = self.nar_text_prenet(x) - x = self.nar_text_position(x) - - y_emb, prefix_len = self._prepare_prompts( - y, y_lens, codes, nar_stage, y_prompts_codes - ) - - y_len = y_lens.max() - targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int - if self.prefix_mode in [2, 4]: - xy_padding_mask = torch.concat( - [ - x_mask, - F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), - ], - dim=1, - ) - elif self.prefix_mode == 1: - targets = targets[:, prefix_len:] - - y_pos = self.nar_audio_prenet(y_emb) - y_pos = 
self.nar_audio_position(y_pos) - xy_pos = torch.concat([x, y_pos], dim=1) - xy_dec, _ = self.nar_decoder( - (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), - src_key_padding_mask=xy_padding_mask, - # is_causal=False, - ) - xy_dec = xy_dec[:, x_lens.max() + prefix_len :] - if self.prefix_mode == 4: - prefix_len = 0 # reset for Top10Accuracy metric - logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1) - - # loss - total_length = (y_lens).sum().type(torch.float32) - total_loss += F.cross_entropy( - logits, - targets, - ignore_index=NUM_AUDIO_TOKENS, - reduction=reduction, - ) * (total_length / (total_length - prefix_len * x.shape[0])) - metrics["NarTop10Accuracy"] = ( - self.nar_accuracy_metric( - F.pad( - logits.detach(), - (0, 0, 0, 1, 0, 0), - value=logits.min().cpu().item(), - ), - targets, - ).item() - * total_length - ) - - if train_stage == 0: - total_loss = total_loss / 2.0 - - return ((x, codes), total_loss, metrics) - - def inference( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: torch.Tensor, - enroll_x_lens: torch.Tensor, - top_k: int = -100, - temperature: float = 1.0, - top_p: float = 1.0, - ras: bool = False, - ) -> torch.Tensor: - """ - Args: - x: - A 2-D tensor of shape (1, S). - x_lens: - A 1-D tensor of shape (1,). It contains the number of tokens in `x` - before padding. - y: - A 3-D tensor of shape (1, T, 8). - top_k: (`optional`) int - The number of highest probability tokens to keep for top-k-filtering. Default to -100. - temperature: (`optional`) float - The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. - ras: (`optional`) bool - Whether to use repetition-aware sampling. Default to False. - Returns: - Return the predicted audio code matrix. 
- """ - assert x.ndim == 2, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.ndim == 3, y.shape - assert y.shape[0] == 1, y.shape - - assert torch.all(x_lens > 0) - - # NOTE: x has been padded in TextTokenCollater - text = x - x = self.ar_text_embedding(text) - x = self.ar_text_prenet(x) - x = self.ar_text_position(x) - - text_len = x_lens.max() - prompts = y - prefix_len = y.shape[1] - - # AR Decoder - # TODO: Managing decoder steps avoid repetitive computation - y = prompts[..., 0] - if self.ar_audio_prepend_bos: - y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) - - x_len = x_lens.max() - x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) - - while True: - y_emb = self.ar_audio_embedding(y) - y_emb = self.ar_audio_prenet(y_emb) - y_pos = self.ar_audio_position(y_emb) - xy_pos = torch.concat([x, y_pos], dim=1) - - y_len = y.shape[1] - x_attn_mask_pad = F.pad( - x_attn_mask, - (0, y_len), - value=True, - ) - y_attn_mask = F.pad( - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), - (x_len, 0), - value=False, - ) - xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( - y.device - ) - - xy_dec, _ = self.ar_decoder( - (xy_pos, None), - mask=xy_attn_mask, - ) - logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = topk_sampling( - logits, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_aware_sampling=ras, - preceding_tokens=y, - ) - - if ( - torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS - or samples[0, 0] == NUM_AUDIO_TOKENS - or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 - ): - if prompts.shape[1] == y.shape[1]: - raise SyntaxError("well trained model shouldn't reach here.") - break - - y = torch.concat([y, samples], dim=1) - - codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] - if self.num_quantizers == 1: - return torch.stack(codes, dim=-1) - - # Non-AR Decoders - y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :]) - - if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes - enrolled_len = enroll_x_lens.max().item() - # SOS + Synthesis Text + EOS - text = torch.concat( - [ - text[:, :1], - text[:, enrolled_len - 1 :], - ], - dim=1, - ) - text_len = text_len - (enrolled_len - 2) - assert text.shape[0] == 1 - - x = self.nar_text_embedding(text) - x = self.nar_text_prenet(x) - x = self.nar_text_position(x) - - if self.prefix_mode == 0: - for i, (predict_layer, embedding_layer) in enumerate( - zip( - self.nar_predict_layers, - self.nar_audio_embeddings[1:], - ) - ): - y_pos = self.nar_audio_prenet(y_emb) - y_pos = self.nar_audio_position(y_pos) - xy_pos = torch.concat([x, y_pos], dim=1) - - xy_dec, _ = self.nar_decoder( - (xy_pos, self.nar_stage_embeddings[i].weight) - ) - logits = predict_layer(xy_dec[:, text_len + prefix_len :]) - - samples = torch.argmax(logits, dim=-1) - codes.append(samples) - - if i < self.num_quantizers - 2: - y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) - y_emb[:, prefix_len:] += embedding_layer(samples) - else: - for j in range(1, self.num_quantizers): - y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) - - for i, (predict_layer, embedding_layer) in enumerate( - zip( - self.nar_predict_layers, - self.nar_audio_embeddings[1:], - ) - ): - y_pos = self.nar_audio_prenet(y_emb) - y_pos = self.nar_audio_position(y_pos) - xy_pos = torch.concat([x, y_pos], dim=1) - - xy_dec, _ = self.nar_decoder( - (xy_pos, self.nar_stage_embeddings[i].weight) - ) - logits = predict_layer(xy_dec[:, text_len + prefix_len 
:]) - - samples = torch.argmax(logits, dim=-1) - codes.append(samples) - - if i < self.num_quantizers - 2: - y_emb[:, prefix_len:] += embedding_layer(samples) - - assert len(codes) == self.num_quantizers - return torch.stack(codes, dim=-1) - - def visualize( - self, - predicts: Tuple[torch.Tensor], - batch: Dict[str, Union[List, torch.Tensor]], - tokenizer: TextTokenCollater, - output_dir: str, - limit: int = 4, - ) -> None: - audio_features = batch["features"].to("cpu").detach().numpy() - audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() - - tokens = batch["tokens"] - text_tokens, text_tokens_lens = tokenizer(tokens) - assert text_tokens.ndim == 2 - - texts = batch["text"] - utt_ids = [cut.id for cut in batch["cut"]] - - encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() - decoder_outputs = predicts[1] - if isinstance(decoder_outputs, list): - decoder_outputs = decoder_outputs[-1] - decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() - - vmin, vmax = 0, 1024 # Encodec - - num_figures = 3 - for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): - _ = plt.figure(figsize=(14, 8 * num_figures)) - - S = text_tokens_lens[b] - T = audio_features_lens[b] - - # encoder - plt.subplot(num_figures, 1, 1) - plt.title(f"Text: {text}") - plt.imshow( - X=np.transpose(encoder_outputs[b]), - cmap=plt.get_cmap("jet"), - aspect="auto", - interpolation="nearest", - ) - plt.gca().invert_yaxis() - plt.axvline(x=S - 0.4, linewidth=2, color="r") - plt.xlabel("Encoder Output") - plt.colorbar() - - # decoder - plt.subplot(num_figures, 1, 2) - plt.imshow( - X=np.transpose(decoder_outputs[b]), - cmap=plt.get_cmap("jet"), - aspect="auto", - interpolation="nearest", - vmin=vmin, - vmax=vmax, - ) - plt.gca().invert_yaxis() - plt.axvline(x=T - 0.4, linewidth=2, color="r") - plt.xlabel("Decoder Output") - plt.colorbar() - - # target - plt.subplot(num_figures, 1, 3) - plt.imshow( - X=np.transpose(audio_features[b]), - cmap=plt.get_cmap("jet"), - aspect="auto", - interpolation="nearest", - vmin=vmin, - vmax=vmax, - ) - plt.gca().invert_yaxis() - plt.axvline(x=T - 0.4, linewidth=2, color="r") - plt.xlabel("Decoder Target") - plt.colorbar() - - plt.savefig(f"{output_dir}/{utt_id}.png") - plt.close() - - -# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py -def top_k_top_p_filtering( - logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) - logits[indices_to_remove] = filter_value - return logits - - -def topk_sampling( - logits, - top_k=10, - top_p=1.0, - temperature=1.0, - repetition_aware_sampling=False, - preceding_tokens=None, -): - if temperature != 1.0: - logits = logits / temperature - # Top-p/top-k filtering - logits_filtered = top_k_top_p_filtering( - logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2 - ) - # Sample - probs = F.softmax(logits_filtered, dim=-1) - tokens = torch.multinomial(probs, num_samples=1) - - if repetition_aware_sampling: - window_size = 10 - threshold = 0.1 - # we first generate the target code ct′ - # by nucleus sampling with a pre-defined top-p value v. Then, we - # calculate the repetition ratio r of token ct′ - # in the preceding code sequence with a window size K. - # If the ratio r exceeds a pre-defined repetition threshold ratio tn, we replace the target code ct′ - # by - # random sampling from p(ct′ - # |x, c<t′; θAR). - if preceding_tokens.shape[1] > window_size: - preceding_tokens = preceding_tokens[:, -window_size:] - if preceding_tokens.shape[1] > 0: - for i, item in enumerate(preceding_tokens): - # check if the repeat ratio exceeds the threshold - if (item == tokens[i]).sum() / window_size > threshold: - # replace the target code ct′ by random sampling - probs = F.softmax(logits[i], dim=-1) - token_new = torch.multinomial(probs, num_samples=1) - tokens[i] = token_new - return tokens diff --git a/egs/xbmu_amdo31/ASR/README.md b/egs/xbmu_amdo31/ASR/README.md deleted file mode 100644 index 0a441d070..000000000 --- a/egs/xbmu_amdo31/ASR/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Introduction -About the XBMU-AMDO31 corpus -XBMU-AMDO31 is an open-source Amdo Tibetan speech corpus published by Northwest Minzu University. -It is publicly available at https://huggingface.co/datasets/syzym/xbmu_amdo31 - -XBMU-AMDO31 is a speech recognition corpus of the Amdo Tibetan dialect. -The open-source corpus contains 31 hours of speech data and resources related -to building speech recognition systems, including transcribed texts and a Tibetan -pronunciation lexicon.
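(Editor's aside, not part of the diff: as a usage note for the sampling helpers deleted above, `topk_sampling` applies the temperature, then `top_k_top_p_filtering`, then draws one token per batch entry and, when repetition-aware sampling is enabled, re-draws from the unfiltered distribution if the drawn token repeats too often in the recent history. A minimal sketch of a call, assuming `topk_sampling` from the file above is in scope and that the AR head predicts NUM_AUDIO_TOKENS + 1 classes, i.e. 1025 if the EnCodec codebook size is 1024:)

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 1025)              # (batch, vocabulary) from ar_predict_layer
history = torch.randint(0, 1024, (1, 50))  # codes emitted so far for this utterance

tokens = topk_sampling(
    logits,
    top_k=10,                      # keep only the 10 most likely codes
    top_p=1.0,                     # nucleus filtering disabled
    temperature=1.0,
    repetition_aware_sampling=True,
    preceding_tokens=history,      # required when repetition-aware sampling is on
)
print(tokens.shape)                # torch.Size([1, 1])
```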
-(The lexicon is a Tibetan lexicon of the Lhasa dialect, which has been reused -for the Amdo dialect because of the uniformity of the Tibetan language) -The dataset can be used to train a model for Amdo Tibetan Automatic Speech Recognition (ASR). - -This recipe includes some different ASR models trained with XBMU-AMDO31. - -[./RESULTS.md](./RESULTS.md) contains the latest results. \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/RESULTS.md b/egs/xbmu_amdo31/ASR/RESULTS.md deleted file mode 100644 index 1bd9b2e2b..000000000 --- a/egs/xbmu_amdo31/ASR/RESULTS.md +++ /dev/null @@ -1,92 +0,0 @@ -## Results - -### XBMU-AMDO31 BPE training result (Stateless Transducer) - -#### Pruned transducer stateless 5 - -[./pruned_transducer_stateless5](./pruned_transducer_stateless5) - -It uses pruned RNN-T. - -A pre-trained model and decoding logs can be found at - -You can use to deploy it. - -Number of model parameters: 87801200, i.e., 87.8 M - -| | test | dev | comment | -|------------------------|------|------|---------------------------------------| -| greedy search | 11.06| 11.73| --epoch 28 --avg 23 --max-duration 600| -| beam search | 10.64| 11.42| --epoch 28 --avg 23 --max-duration 600| -| modified beam search | 10.57| 11.24| --epoch 28 --avg 23 --max-duration 600| - - -Training command is: - -```bash -cd egs/xbmu_amdo31/ASR -./prepare.sh - -export CUDA_VISIBLE_DEVICES="0" - -./pruned_transducer_stateless5/train.py -``` - -**Caution**: It uses `--context-size=1`. - - -The decoding command is: -```bash -for method in greedy_search beam_search modified_beam_search; -do -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 23 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method $method -done -``` - -### pruned_transducer_stateless7 (zipformer) - -See for more details. - -[pruned_transducer_stateless7](./pruned_transducer_stateless7) - -You can find a pretrained model, training logs, decoding logs, and decoding -results at: - - -You can use to deploy it. 
- -Number of model parameters: 70369391, i.e., 70.37 M - -| | test | dev | comment | -|----------------------|------|------|----------------------------------------| -| greedy search | 10.06| 10.59| --epoch 23 --avg 11 --max-duration 600 | -| beam search | 9.77 | 10.11| --epoch 23 --avg 11 --max-duration 600 | -| modified beam search | 9.7 | 10.12| --epoch 23 --avg 11 --max-duration 600 | - -The training commands are: -```bash -export CUDA_VISIBLE_DEVICES="0" - -./pruned_transducer_stateless7/train.py -``` - -The decoding commands are: -```bash -for m in greedy_search beam_search modified_beam_search; do - for epoch in 23; do - for avg in 11; do - ./pruned_transducer_stateless7/decode.py \ - --epoch $epoch \ - --avg $avg \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method $m - done - done -done -``` diff --git a/egs/xbmu_amdo31/ASR/local/compile_hlg.py b/egs/xbmu_amdo31/ASR/local/compile_hlg.py deleted file mode 120000 index 471aa7fb4..000000000 --- a/egs/xbmu_amdo31/ASR/local/compile_hlg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compile_lg.py b/egs/xbmu_amdo31/ASR/local/compile_lg.py deleted file mode 120000 index 462d6d3fb..000000000 --- a/egs/xbmu_amdo31/ASR/local/compile_lg.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py deleted file mode 100755 index a593e7be3..000000000 --- a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the XBMU-AMDO31 dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import sentencepiece as spm -import torch -from filter_cuts import filter_cuts -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to the bpe.model. If not None, we will remove short and - long utterances before extracting features""", - ) - return parser.parse_args() - - -def compute_fbank_xbmu_amdo31(bpe_model: Optional[str] = None): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - if bpe_model: - logging.info(f"Loading {bpe_model}") - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - - dataset_parts = ( - "train", - "dev", - "test", - ) - prefix = "xbmu_amdo31" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if bpe_model: - cut_set = filter_cuts(cut_set, sp) - - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - compute_fbank_xbmu_amdo31(bpe_model=args.bpe_model) diff --git a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/filter_cuts.py b/egs/xbmu_amdo31/ASR/local/filter_cuts.py deleted file mode 120000 index 27aca1729..000000000 --- a/egs/xbmu_amdo31/ASR/local/filter_cuts.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py deleted file mode 120000 index c0aea1403..000000000 --- a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang.py b/egs/xbmu_amdo31/ASR/local/prepare_lang.py deleted file mode 120000 index 747f2ab39..000000000 --- a/egs/xbmu_amdo31/ASR/local/prepare_lang.py +++ /dev/null @@ -1 
+0,0 @@ -../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py deleted file mode 120000 index abc00d421..000000000 --- a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py deleted file mode 120000 index 1d6ccbe33..000000000 --- a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh deleted file mode 100755 index 21836840c..000000000 --- a/egs/xbmu_amdo31/ASR/prepare.sh +++ /dev/null @@ -1,357 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -nj=15 -stage=-1 -stop_stage=100 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/xbmu_amdo31 -# You can find data, resource, etc, inside it. -# You can download them from https://huggingface.co/datasets/syzym/xbmu_amdo31 -# -# - $dl_dir/lm -# This directory contains the following files downloaded from -# git lfs install -# https://huggingface.co/syzym/xbmu_amdo31_lm -# -# - tibetan.3-gram.arpa -# - tibetan.4-gram.arpa -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" - # We assume that you have installed the git-lfs, if not, you could install it - # using: `sudo apt-get install git-lfs && git-lfs install` - git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1) - - if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then - git clone https://huggingface.co/syzym/xbmu_amdo31_lm $dl_dir/lm - pushd $dl_dir/lm - git lfs pull --include "tibetan.3-gram.arpa" - git lfs pull --include "tibetan.4-gram.arpa" - popd - fi -fi - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/xbmu_amdo31, - # you can create a symlink - # - # ln -sfv /path/to/xbmu_amdo31 $dl_dir/xbmu_amdo31 - # - - if [ ! -f $dl_dir/xbmu_amdo31 ]; then - git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1) - lhotse download xbmu-amdo31 $dl_dir - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare xbmu_amdo31 manifest" - # We assume that you have downloaded the xbmu_amdo31 corpus - # to $dl_dir/xbmu_amdo31 - if [ ! -f data/manifests/.xbmu_amdo31_manifests.done ]; then - mkdir -p data/manifests - lhotse prepare xbmu-amdo31 $dl_dir/xbmu_amdo31 data/manifests - touch data/manifests/.xbmu_amdo31_manifests.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - if [ ! -f data/manifests/.musan_manifests.done ]; then - log "It may take 6 minutes" - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan_manifests.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for xbmu_amdo31" - if [ ! -f data/fbank/.xbmu_amdo31.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_xbmu_amdo31.py - touch data/fbank/.xbmu_amdo31.done - fi -fi - - - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for musan" - if [ ! -f data/fbank/.msuan.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_musan.py - touch data/fbank/.msuan.done - fi -fi - - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" - lang_dir=data/lang_phone - mkdir -p $lang_dir - - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/xbmu_amdo31/resource/lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - - ./local/generate_unique_lexicon.py --lang-dir $lang_dir - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi -fi - - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. 
- cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data to train phone based bigram P" - xbmu_amdo31_text=$dl_dir/xbmu_amdo31/data/transcript/transcript_clean.txt - xbmu_amdo31_train_uid=$dl_dir/xbmu_amdo31/data/transcript/xbmu_amdo31_train_uid - find $dl_dir/xbmu_amdo31/data/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '-' '{print $NF}' > $xbmu_amdo31_train_uid - awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $xbmu_amdo31_train_uid $xbmu_amdo31_text | - cut -d " " -f 2- > $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram P" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/tibetan.3-gram.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! 
-f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/tibetan.4-gram.arpa > data/lm/G_4_gram.fst.txt - fi -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - done -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compile LG" - ./local/compile_lg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Generate LM training data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/lm/lm_train.txt \ - --lm-archive $out_dir/lm_data.pt - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Generate LM validation data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/valid.txt ]; then - files=$dl_dir/xbmu_amdo31/data/transcript/dev_text - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $out_dir/valid.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Generate LM test data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/test.txt ]; then - files=$dl_dir/xbmu_amdo31/data/transcript/test_text - cat $f | cut -d " " -f 2- > $out_dir/test.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/test.txt \ - --lm-archive $out_dir/lm_data-test.pt - done -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Sort LM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. 
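(Editor's aside, not part of the diff: the sorting criterion described in the comment above is simply "more BPE tokens first". In plain Python the idea reduces to the toy snippet below; the actual work is done by `local/sort_lm_training_data.py` on the .pt archives.)

```python
# Each sentence is a list of BPE token IDs; sort by length, longest first.
sentences = [[23, 7, 980], [5], [41, 41]]
sorted_sentences = sorted(sentences, key=len, reverse=True)
print([len(s) for s in sorted_sentences])  # [3, 2, 1]
```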
- - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt - done -fi diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py deleted file mode 100644 index 7b37b1331..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# Copyright 2022 Northwest Minzu University (Author: Senyan Li) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import CutConcatenate # noqa F401 for PrecomputedFeatures -from lhotse.dataset import ( - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class Xbmu_AmdoAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. 
- """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - f = self.args.manifest_dir / "xbmu_amdo31_cuts_train.jsonl.gz" - logging.info(f"About to get train cuts from {f}") - cuts_train = load_manifest_lazy(f) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - f = self.args.manifest_dir / "xbmu_amdo31_cuts_dev.jsonl.gz" - logging.info(f"About to get valid cuts from {f}") - cuts_valid = load_manifest_lazy(f) - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - f = self.args.manifest_dir / "xbmu_amdo31_cuts_test.jsonl.gz" - logging.info(f"About to get test cuts from {f}") - cuts_test = 
load_manifest_lazy(f) - return cuts_test diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py deleted file mode 120000 index c7c1a4b6e..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py deleted file mode 100755 index b77f734e3..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ /dev/null @@ -1,964 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
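For orientation, a minimal sketch of how the Xbmu_AmdoAsrDataModule deleted above is typically wired up from the command line; the argument values are placeholders and the exact wiring in the recipe's train.py/decode.py may differ:

    import argparse
    from asr_datamodule import Xbmu_AmdoAsrDataModule

    parser = argparse.ArgumentParser()
    Xbmu_AmdoAsrDataModule.add_arguments(parser)
    args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

    datamodule = Xbmu_AmdoAsrDataModule(args)
    # Cuts are loaded lazily from xbmu_amdo31_cuts_{train,dev,test}.jsonl.gz.
    train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
    valid_dl = datamodule.valid_dataloaders(datamodule.valid_cuts())
    test_dl = datamodule.test_dataloaders(datamodule.test_cuts())
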
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method greedy_search -(2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 -(3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 -(4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -(5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 -(7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(8) modified beam search with RNNLM shallow fusion (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 - - -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import Xbmu_AmdoAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_rnnlm_shallow_fusion, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint 
to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. - It specifies the scale for n-gram LM scores. 
- """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is fast_beam_search_LG, - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is fast_beam_search_LG, - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. 
- """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", - type=str2bool, - default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if ( - params.decoding_method == "fast_beam_search" - or params.decoding_method == "fast_beam_search_LG" - ): - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - if params.decoding_method == "fast_beam_search": - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, - ) - for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
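The decode_dataset docstring above speaks of two-element tuples, but the loop that follows actually appends three items per utterance: the cut id, the reference words, and the hypothesis words. Purely as an illustration (the cut id and words are invented), the returned structure looks like this:

    # Illustrative only -- the shape of what decode_dataset() returns.
    # Keys come from decode_one_batch(), e.g. "greedy_search", "beam_size_4",
    # or a "beam_20.0_max_contexts_8_max_states_64"-style key for fast_beam_search.
    results = {
        "greedy_search": [
            # (cut id, reference words, hypothesis words)
            ("xbmu_amdo31_test_0001", ["foo", "bar"], ["foo", "baz"]),
        ],
    }
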
- - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - logging.info(f"Decoding {batch_idx}-th batch") - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - Xbmu_AmdoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_rnnlm_shallow_fusion", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix 
+= f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - 
model.eval() - - rnn_lm_model = None - rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, - ) - assert params.rnn_lm_avg == 1 - - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - xbmu_amdo = Xbmu_AmdoAsrDataModule(args) - - test_cuts = xbmu_amdo.test_cuts() - - test_dl = xbmu_amdo.test_dataloaders(test_cuts) - - test_sets = ["test"] - test_dl = [test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py deleted file mode 120000 index d59ef95f7..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 120000 index 722e1c894..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 54f656859..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,287 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the 
Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: -./pruned_transducer_stateless5/export.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `pruned_transducer_stateless5/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless5/decode.py \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--streaming-model", - type=str2bool, - default=False, - help="""Whether to export a streaming model, if the models in exp-dir - are streaming model, this should be True. - """, - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.streaming_model: - assert params.causal_convolution - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. 
- convert_scaled_to_non_scaled(model, inplace=True) - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 120000 index 9052f3cbb..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py deleted file mode 120000 index b82e115fc..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py deleted file mode 120000 index a99e74334..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py deleted file mode 120000 index 0a2f285aa..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py deleted file mode 100755 index 2c106c4cb..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py +++ /dev/null @@ -1,345 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
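A short sketch of how the two artifacts written by the export.py above are typically consumed; the paths are placeholders, and the pretrained.py that follows shows the full checkpoint route in context:

    import torch

    # cpu_jit.pt (written when --jit is true) is a TorchScript module.
    jit_model = torch.jit.load("pruned_transducer_stateless5/exp/cpu_jit.pt")
    jit_model.eval()

    # pretrained.pt is a plain checkpoint holding the state_dict under "model";
    # it is loaded into a freshly built transducer, as pretrained.py does below.
    checkpoint = torch.load("pruned_transducer_stateless5/exp/pretrained.pt", map_location="cpu")
    state_dict = checkpoint["model"]
    # model.load_state_dict(state_dict)  # `model` built via get_transducer_model(params)
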
-""" -Usage: - -(1) greedy search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by -./pruned_transducer_stateless5/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - 
encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 120000 index c10cdfe12..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py deleted file mode 120000 index db93d155b..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py deleted file mode 120000 index 1199a61d6..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py deleted file mode 120000 index f29284163..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py deleted file mode 100755 index 9aad32014..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 
Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py -""" - -from train import get_params, get_transducer_model - - -def test_model_1(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 24 - params.dim_feedforward = 1536 # 384 * 4 - params.encoder_dim = 384 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf -def test_model_M(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = 18 - params.dim_feedforward = 1024 - params.encoder_dim = 256 - params.nhead = 4 - params.decoder_dim = 512 - params.joiner_dim = 512 - model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - -def main(): - # test_model_1() - test_model_M() - - -if __name__ == "__main__": - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py deleted file mode 100755 index a6fa46b17..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ /dev/null @@ -1,1187 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import Xbmu_AmdoAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. 
We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value used to penalize symbol delay, - to encourage streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details.""", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - dynamic_chunk_training=params.dynamic_chunk_training, - short_chunk_size=params.short_chunk_size, - num_left_chunks=params.num_left_chunks, - causal=params.causal_convolution, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - # If the batch contains more than 10 utterances AND - # if either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - xbmu_amdo = Xbmu_AmdoAsrDataModule(args) - - train_cuts = xbmu_amdo.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. 
" - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = xbmu_amdo.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = xbmu_amdo.valid_cuts() - valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - display_and_save_batch(batch, params=params, sp=sp) - raise - - -def main(): - parser = get_parser() - Xbmu_AmdoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py deleted file mode 120000 index c473a600a..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless5/asr_datamodule.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py deleted file mode 100755 index e334e690a..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,837 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: -(1) greedy search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import Xbmu_AmdoAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: 
{params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - Xbmu_AmdoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, 
--avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
- args.return_cuts = True - xbmu_amdo = Xbmu_AmdoAsrDataModule(args) - - test_cuts = xbmu_amdo.test_cuts() - - test_dl = xbmu_amdo.test_dataloaders(test_cuts) - - test_sets = [ - "test", - ] - test_dl = [ - test_dl, - ] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py deleted file mode 120000 index 8283d8c5a..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py deleted file mode 120000 index f58253127..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py deleted file mode 120000 index 2713792e6..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py deleted file mode 120000 index a44034e34..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py deleted file mode 120000 index 0f0c3c90a..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py deleted file mode 120000 index 0d8bc665b..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py deleted file mode 120000 index 8a05abb5f..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100755 index 6995ff2ff..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,356 +0,0 @@ 
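The `save_results` logic shown earlier writes one `errs-*` file per decoding setting and a tab-separated `wer-summary-*` file with the best setting listed first. Below is a stripped-down restatement of that summary step, with a hypothetical output path and `{setting: WER}` mapping.

```python
# Stripped-down restatement of the WER summary written by save_results above:
# rank the (setting, WER) pairs so the best setting comes first, then write
# them as a tab-separated table with a "settings\tWER" header.
from pathlib import Path
from typing import Dict


def write_wer_summary(path: Path, test_set_wers: Dict[str, float]) -> None:
    ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
    with open(path, "w") as f:
        print("settings\tWER", file=f)
        for key, val in ranked:
            print(f"{key}\t{val}", file=f)


# Hypothetical usage; the keys follow the params.suffix naming used above:
# write_wer_summary(Path("wer-summary-test.txt"),
#                   {"greedy_search": wer_greedy, "modified_beam_search": wer_mbs})
```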
-#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). 
" - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = 
"\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py deleted file mode 120000 index f9960e5c6..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py deleted file mode 120000 index 7ceac5d10..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py deleted file mode 100755 index dd72551d9..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1221 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import Xbmu_AmdoAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.err import raise_grad_scale_is_too_small_error -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. 
- - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. 
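`get_encoder_model` above converts the comma-separated CLI strings into per-stack tuples with a tiny helper before building the Zipformer. Spelling that conversion out with the default values from `add_model_arguments` makes the one-value-per-stack convention explicit.

```python
# The comma-separated encoder options describe one value per Zipformer stack;
# to_int_tuple (as defined in get_encoder_model above) just splits and casts.
def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


num_encoder_layers = to_int_tuple("2,4,3,2,4")          # layers per stack
downsampling_factors = to_int_tuple("1,2,4,8,2")        # per-stack subsampling
encoder_dims = to_int_tuple("384,384,384,384,384")      # embedding dims

# Every option must describe the same number of stacks (five here).
assert len(num_encoder_layers) == len(downsampling_factors) == len(encoder_dims)
```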
- """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. 
- world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
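The fp16 branch described just above nudges the scaler's scale upward by hand every 100 batches, because `GradScaler`'s built-in growth interval cannot be made to depend on the current scale. Here is a sketch of that heuristic in isolation, using the same thresholds; the icefall-specific `raise_grad_scale_is_too_small_error` is replaced by a plain `RuntimeError`.

```python
# Sketch of the manual grad-scale check run every 100 batches when fp16 is on:
# double small scales, warn when the scale collapses, and stop training when
# it is effectively zero.
import logging

from torch.cuda.amp import GradScaler


def maybe_grow_grad_scale(scaler: GradScaler, batch_idx: int) -> None:
    cur = scaler._scale.item()  # same private attribute the training loop reads
    if cur < 1.0 or (cur < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)  # GradScaler.update() accepts an explicit scale
    if cur < 0.01:
        logging.warning(f"Grad scale is small: {cur}")
    if cur < 1.0e-05:
        raise RuntimeError(f"grad scale {cur} is too small; training is diverging")
```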
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise_grad_scale_is_too_small_error(cur_grad_scale) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - xbmu_amdo = Xbmu_AmdoAsrDataModule(args) - - train_cuts = xbmu_amdo.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. 
" - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = xbmu_amdo.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = xbmu_amdo.valid_cuts() - valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - Xbmu_AmdoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py deleted file mode 120000 index f2f66041e..000000000 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/shared b/egs/xbmu_amdo31/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/xbmu_amdo31/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md deleted file mode 100644 index c9a2b56b1..000000000 --- a/egs/yesno/ASR/README.md +++ /dev/null @@ -1,14 +0,0 @@ -## Yesno recipe - -This is the simplest ASR recipe in `icefall`. - -It can be run on CPU and takes less than 30 seconds to -get the following WER: - -``` -[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] -``` - -Please refer to - -for detailed instructions. diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py deleted file mode 100755 index e0a94bf08..000000000 --- a/egs/yesno/ASR/local/compile_hlg.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python3 - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - logging.info("Loading G.fst.txt") - with open("data/lm/G.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py deleted file mode 100755 index 75d95df68..000000000 --- a/egs/yesno/ASR/local/compute_fbank_yesno.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file computes fbank features of the yesno dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or it wastes a -# lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_yesno(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - # This dataset is rather small, so we use only one job - num_jobs = min(1, os.cpu_count()) - num_mel_bins = 23 - - dataset_parts = ( - "train", - "test", - ) - prefix = "yesno" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" - if cuts_file.is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 1, # use one job - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(cuts_file) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_yesno() diff --git a/egs/yesno/ASR/local/prepare_lang.py b/egs/yesno/ASR/local/prepare_lang.py deleted file mode 100755 index f7fde7796..000000000 --- a/egs/yesno/ASR/local/prepare_lang.py +++ /dev/null @@ -1,367 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon - -Lexicon = List[Tuple[str, List[str]]] - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. 
- Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. 
The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. 
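# Toy illustration of the disambiguation step defined above. The words and
# pronunciations are made up (not the yesno lexicon), and it assumes
# add_disambig_symbols from this script is in scope; shown only to make the
# prefix/duplicate handling concrete.
toy_lexicon = [
    ("YES",  ["Y"]),
    ("YEAH", ["Y"]),        # duplicate pronunciation of YES
    ("NO",   ["N"]),
    ("NONE", ["N", "O"]),   # "N" is a prefix of "N O"
]
toy_disambig, max_disambig = add_disambig_symbols(toy_lexicon)
# Expected result, in input order:
#   YES  -> Y #1   (duplicate pronunciation)
#   YEAH -> Y #2
#   NO   -> N #1   (its pronunciation is a prefix of another)
#   NONE -> N O    (unique and not a prefix, so unchanged)
# max_disambig == 2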
- i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - out_dir = Path("data/lang_phone") - lexicon_filename = out_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py deleted file mode 120000 index c5787c534..000000000 --- a/egs/yesno/ASR/local/prepare_lang_fst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh deleted file mode 100755 index 41db0cf7c..000000000 --- a/egs/yesno/ASR/prepare.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download - -lang_dir=data/lang_phone -lm_dir=data/lm - -. shared/parse_options.sh || exit 1 - -mkdir -p $lang_dir -mkdir -p $lm_dir - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - mkdir -p $dl_dir - - if [ ! 
-f $dl_dir/waves_yesno/.completed ]; then - lhotse download yesno $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare yesno manifest" - mkdir -p data/manifests - lhotse prepare yesno $dl_dir/waves_yesno data/manifests -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute fbank for yesno" - mkdir -p data/fbank - ./local/compute_fbank_yesno.py -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare lang" - # NOTE: " SIL" is added for implementation convenience - # as the graph compiler code requires that there is a OOV word - # in the lexicon. - ( - echo " SIL" - echo "YES Y" - echo "NO N" - echo " SIL" - ) > $lang_dir/lexicon.txt - - ./local/prepare_lang.py - ./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1 -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare G" - # We use a unigram G - cat < $lm_dir/G.arpa - -\data\\ -ngram 1=4 - -\1-grams: --1 NO --1 YES --99 --1 - -\end\\ - -EOF - - if [ ! -f $lm_dir/G.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/words.txt" \ - --disambig-symbol='#0' \ - $lm_dir/G.arpa > $lm_dir/G.fst.txt - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Compile HLG" - if [ ! -f $lang_dir/HLG.pt ]; then - ./local/compile_hlg.py --lang-dir $lang_dir - fi -fi diff --git a/egs/yesno/ASR/shared b/egs/yesno/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/yesno/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file diff --git a/egs/yesno/ASR/tdnn/README.md b/egs/yesno/ASR/tdnn/README.md deleted file mode 100644 index 1b7ddcaf1..000000000 --- a/egs/yesno/ASR/tdnn/README.md +++ /dev/null @@ -1,8 +0,0 @@ - -## How to run this recipe - -You can find detailed instructions by visiting - - -It describes how to run this recipe and how to use -a pre-trained model with `./pretrained.py`. diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py deleted file mode 100644 index 99f2a6d08..000000000 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import List - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class YesNoAsrDataModule(DataModule): - """ - DataModule for k2 ASR experiments. 
- It assumes there is always one train dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=30.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=False, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=10, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=False, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to create train dataset") - transforms = [] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
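# Rough usage sketch of this data module, mirroring how decode.py further down
# consumes it; the argument values and the loop body are illustrative only.
import argparse
from asr_datamodule import YesNoAsrDataModule

parser = argparse.ArgumentParser()
YesNoAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "30", "--num-workers", "2"])

yes_no = YesNoAsrDataModule(args)
train_dl = yes_no.train_dataloaders()
test_dl = yes_no.test_dataloaders()

for batch in train_dl:
    feats = batch["inputs"]                # (N, T, C) fbank features
    texts = batch["supervisions"]["text"]  # reference transcripts
    break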
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=23)) - ), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - - return train_dl - - def test_dataloaders(self) -> DataLoader: - logging.info("About to get test cuts") - cuts_test = self.test_cuts() - - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts_test, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.feature_dir / "yesno_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - cuts_test = load_manifest_lazy( - self.args.feature_dir / "yesno_cuts_test.jsonl.gz" - ) - return cuts_test diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py deleted file mode 100755 index f520607af..000000000 --- a/egs/yesno/ASR/tdnn/decode.py +++ /dev/null @@ -1,319 +0,0 @@ -#!/usr/bin/env python3 - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import YesNoAsrDataModule -from model import Tdnn - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import get_lattice, one_best_decoding -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - 
AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=14, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=2, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn/exp/"), - "lang_dir": Path("data/lang_phone"), - "feature_dim": 23, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - word_table: k2.SymbolTable, -) -> List[List[int]]: - """Decode one batch and return the result in a list-of-list. - Each sub list contains the word IDs for an utterance in the batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding. - - params.method is "nbest", it uses nbest decoding. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) - word_table: - It is the word symbol table. - Returns: - Return the decoding result. `len(ans)` == batch size. - """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - nnet_output = model(feature) - # nnet_output is (N, T, C) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - word_table: k2.SymbolTable, -) -> List[Tuple[str, List[str], List[str]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - word_table: - It is word symbol table. 
- Returns: - Return a tuple contains two elements (ref_text, hyp_text): - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - word_table=word_table, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - exp_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -) -> None: - """Save results to `exp_dir`. - Args: - exp_dir: - The output directory. This function create the following files inside - this directory: - - - recogs-{test_set_name}.text - - It contains the reference and hypothesis results, like below:: - - ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - - - errs-{test_set_name}.txt - - It contains the detailed WER. - test_set_name: - The name of the test set, which will be part of the result filename. - results: - A list of tuples, each of which contains (ref_words, hyp_words). - Returns: - Return None. - """ - recog_path = exp_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
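# Sketch of exercising save_results directly: each element of `results` is a
# (cut_id, ref_words, hyp_words) tuple as produced by decode_dataset above.
# The cut IDs and transcripts below are invented purely for illustration.
from pathlib import Path

fake_results = [
    ("0_0_0_1_0_0_0_1-0", ["NO", "NO", "NO", "YES"], ["NO", "NO", "NO", "YES"]),
    ("0_0_1_0_0_0_1_0-0", ["NO", "NO", "YES", "NO"], ["NO", "NO", "YES", "YES"]),
]
save_results(exp_dir=Path("tdnn/exp"), test_set_name="debug", results=fake_results)
# Writes tdnn/exp/recogs-debug.txt (ref/hyp transcripts) and
# tdnn/exp/errs-debug.txt (WER and per-word error statistics).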
- errs_filename = exp_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - write_error_stats(f, f"{test_set_name}", results) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - -@torch.no_grad() -def main(): - parser = get_parser() - YesNoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - yes_no = YesNoAsrDataModule(args) - test_dl = yes_no.test_dataloaders() - results = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - word_table=lexicon.word_table, - ) - - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/tdnn/export.py b/egs/yesno/ASR/tdnn/export.py deleted file mode 100755 index c40cf8cd1..000000000 --- a/egs/yesno/ASR/tdnn/export.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file is for exporting trained models to a checkpoint -or to a torchscript model. - -(1) Generate the checkpoint tdnn/exp/pretrained.pt - -./tdnn/export.py \ - --epoch 14 \ - --avg 2 - -See ./tdnn/pretrained.py for how to use the generated file. - -(2) Generate torchscript model tdnn/exp/cpu_jit.pt - -./tdnn/export.py \ - --epoch 14 \ - --avg 2 \ - --jit 1 - -See ./tdnn/jit_pretrained.py for how to use the generated file. -""" - -import argparse -import logging - -import torch -from model import Tdnn -from train import get_params - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=14, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=2, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py deleted file mode 100755 index 2436ca81b..000000000 --- a/egs/yesno/ASR/tdnn/export_onnx.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file is for exporting trained models to onnx. - -Usage: - - ./tdnn/export_onnx.py \ - --epoch 14 \ - --avg 2 - -The above command generates the following two files: - - ./exp/model-epoch-14-avg-2.onnx - - ./exp/model-epoch-14-avg-2.int8.onnx - -See ./tdnn/onnx_pretrained.py for how to use them. -""" - -import argparse -import logging -from typing import Dict - -import onnx -import torch -from model import Tdnn -from onnxruntime.quantization import QuantType, quantize_dynamic -from train import get_params - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=14, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=2, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. 
- """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to("cpu") - model.eval() - - N = 1 - T = 100 - C = params.feature_dim - x = torch.rand(N, T, C) - - opset_version = 13 - onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx" - torch.onnx.export( - model, - x, - onnx_filename, - verbose=False, - opset_version=opset_version, - input_names=["x"], - output_names=["log_prob"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "log_prob": {0: "N", 1: "T"}, - }, - ) - - logging.info(f"Saved to {onnx_filename}") - meta_data = { - "model_type": "tdnn", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming tdnn for the yesno recipe", - "vocab_size": max_token_id + 1, - } - - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=onnx_filename, meta_data=meta_data) - - logging.info("Generate int8 quantization models") - onnx_filename_int8 = ( - f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx" - ) - - quantize_dynamic( - model_input=onnx_filename, - model_output=onnx_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - logging.info(f"Saved to {onnx_filename_int8}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py deleted file mode 100755 index 6c643c263..000000000 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ /dev/null @@ -1,198 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file shows how to use a torchscript model for decoding. - -Usage: - - ./tdnn/jit_pretrained.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/cpu_jit.pt, -you can use ./export.py --jit 1 -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="""Path to the torchscript model. 
- You can use ./tdnn/export.py --jit 1 - to obtain it - """, - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "num_classes": 4, # [, N, SIL, Y] - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Loading torchscript model") - model = torch.jit.load(args.nn_model) - model.eval() - model.to(device) - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - nnet_output = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " 
".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/yesno/ASR/tdnn/model.py b/egs/yesno/ASR/tdnn/model.py deleted file mode 100755 index 52cff37e0..000000000 --- a/egs/yesno/ASR/tdnn/model.py +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang) - - -import torch -import torch.nn as nn - - -class Tdnn(nn.Module): - def __init__(self, num_features: int, num_classes: int): - """ - Args: - num_features: - Model input dimension. - num_classes: - Model output dimension - """ - super().__init__() - - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=32, - kernel_size=3, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=2, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=4, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - ) - self.output_linear = nn.Linear(in_features=32, out_features=num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The input tensor with shape [N, T, C] - - Returns: - The output tensor has shape [N, T, C] - """ - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] - x = self.tdnn(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] - x = self.output_linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x - - -def test_tdnn(): - num_features = 23 - num_classes = 4 - model = Tdnn(num_features=num_features, num_classes=num_classes) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - N = 2 - T = 100 - C = num_features - x = torch.randn(N, T, C) - y = model(x) - print(x.shape) - print(y.shape) - - -if __name__ == "__main__": - test_tdnn() diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py deleted file mode 100755 index 968a9e9a8..000000000 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file shows how to use an ONNX model for decoding with onnxruntime. 
- -Usage: - -(1) Use a not quantized ONNX model, i.e., a float32 model - - ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -(2) Use a quantized ONNX model, i.e., an int8 model - - ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx, -and ./tdnn/exp/model-epoch-14-avg-2.onnx, -you can use ./export_onnx.py --epoch 14 --avg 2 -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -class OnnxModel: - def __init__(self, nn_model: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - meta = self.model.get_modelmeta().custom_metadata_map - self.vocab_size = int(meta["vocab_size"]) - - def run( - self, - x: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor log_prob of shape (N, T, C) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - }, - ) - return torch.from_numpy(out[0]) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 - to obtain it - """, - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"device: {device}") - - logging.info(f"Loading onnx model {params.nn_model}") - model = OnnxModel(params.nn_model) - - logging.info(f"Loading HLG from {args.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - nnet_output = model.run(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py deleted file mode 100755 index bea520998..000000000 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file shows how to use a checkpoint for decoding. - -Usage: - - ./tdnn/pretrained.py \ - --checkpoint ./tdnn/exp/pretrained.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/pretrained.pt, -you can use ./export.py -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import Tdnn -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint(). " - "You can use ./tdnn/export.py to obtain it.", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "num_classes": 4, # [, N, SIL, Y] - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - - model = Tdnn( - num_features=params.feature_dim, - num_classes=params.num_classes, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - nnet_output = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py deleted file mode 100755 index 335493491..000000000 --- a/egs/yesno/ASR/tdnn/train.py +++ /dev/null @@ -1,575 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import YesNoAsrDataModule -from lhotse.utils import fix_random_seed -from model import Tdnn -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP 
-from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=15, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn/exp"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-2, - "feature_dim": 23, - "weight_decay": 1e-6, - "start_epoch": 0, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 20, - "valid_interval": 10, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Tdnn in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. 
- graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - texts = supervisions["text"] - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. 
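The compute_loss shown above builds supervision_segments rows of the form [sequence_index, start_frame, num_frames] before handing them to k2.DenseFsaVec (the column meaning is inferred from that code, not stated elsewhere in this patch). A small standalone sketch with made-up sizes:

import torch

# hypothetical batch: 4 utterances, each mapped to 100 output frames by the network
batch_size, num_out_frames = 4, 100
supervision_segments = torch.tensor(
    [[i, 0, num_out_frames] for i in range(batch_size)],
    dtype=torch.int32,
)
print(supervision_segments)
# tensor([[  0,   0, 100],
#         [  1,   0, 100],
#         [  2,   0, 100],
#         [  3,   0, 100]], dtype=torch.int32)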
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"device: {device}") - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.SGD( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - yes_no = YesNoAsrDataModule(args) - train_dl = yes_no.train_dataloaders() - - # There are only 60 waves: 30 files are used for training - # and the remaining 30 files are used for testing. - # We use test data as validation. 
- valid_dl = yes_no.test_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=None, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - YesNoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/__init__.py b/egs/yesno/ASR/transducer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/yesno/ASR/transducer/asr_datamodule.py b/egs/yesno/ASR/transducer/asr_datamodule.py deleted file mode 120000 index c9c8adb57..000000000 --- a/egs/yesno/ASR/transducer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn/asr_datamodule.py \ No newline at end of file diff --git a/egs/yesno/ASR/transducer/beam_search.py b/egs/yesno/ASR/transducer/beam_search.py deleted file mode 100644 index b98090636..000000000 --- a/egs/yesno/ASR/transducer/beam_search.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import torch -from transducer.model import Transducer - - -def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[str]: - """ - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - Returns: - Return the decoded result. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - hyp = [] - max_u = 1000 # terminate after this number of steps - u = 0 - - while t < T and u < max_u: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (N, 1, 1) - # TODO: Use logits.argmax() - y = log_prob.argmax() - if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out, (h, c) = model.decoder(y, (h, c)) - u += 1 - else: - t += 1 - id2word = {1: "YES", 2: "NO"} - - hyp = [id2word[i] for i in hyp] - - return hyp diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py deleted file mode 100755 index 7f13e417a..000000000 --- a/egs/yesno/ASR/transducer/decode.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import YesNoAsrDataModule -from transducer.beam_search import greedy_search -from transducer.decoder import Decoder -from transducer.encoder import Tdnn -from transducer.joiner import Joiner -from transducer.model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=125, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--exp-dir", - type=str, - default="transducer/exp", - help="Directory from which to load the checkpoints", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - # encoder/decoder params - "vocab_size": 3, # blank, yes, no - "blank_id": 0, - "embedding_dim": 32, - "hidden_dim": 16, - "num_decoder_layers": 4, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - batch: dict, -) -> List[List[int]]: - """Decode one batch and return the result in a list-of-list. 
- Each sub list contains the word IDs for an utterance in the batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding. - - params.method is "nbest", it uses nbest decoding. - - model: - The neural model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) - Returns: - Return the decoding result. `len(ans)` == batch size. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - feature_lens = batch["supervisions"]["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - hyp = greedy_search(model=model, encoder_out=encoder_out_i) - hyps.append(hyp) - - # hyps = [[word_table[i] for i in ids] for ids in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, -) -> List[Tuple[List[int], List[int]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - Returns: - Return a tuple contains two elements (ref_text, hyp_text): - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - batch=batch, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - exp_dir: Path, - test_set_name: str, - results: List[Tuple[List[int], List[int]]], -) -> None: - """Save results to `exp_dir`. - Args: - exp_dir: - The output directory. This function create the following files inside - this directory: - - - recogs-{test_set_name}.text - - It contains the reference and hypothesis results, like below:: - - ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] - ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] - - - errs-{test_set_name}.txt - - It contains the detailed WER. - test_set_name: - The name of the test set, which will be part of the result filename. - results: - A list of tuples, each of which contains (ref_words, hyp_words). - Returns: - Return None. 
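The errs-{test_set_name}.txt file described above holds the detailed WER computed by write_error_stats. As a rough standalone illustration of the metric itself (a plain Levenshtein distance over words, not icefall's implementation):

def word_error_rate(ref, hyp):
    # edit distance between two word lists, normalised by the reference length
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[-1][-1] / max(len(ref), 1)

print(word_error_rate("NO NO YES".split(), "NO YES YES".split()))  # 0.333..., i.e. 1 substitution / 3 words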
- """ - recog_path = exp_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = exp_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - write_error_stats(f, f"{test_set_name}", results) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - -def get_transducer_model(params: AttributeDict): - encoder = Tdnn( - num_features=params.feature_dim, - output_dim=params.hidden_dim, - ) - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.hidden_dim, - embedding_dropout=0.4, - rnn_dropout=0.4, - ) - joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) - return transducer - - -@torch.no_grad() -def main(): - parser = get_parser() - YesNoAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = get_transducer_model(params) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to(device) - model.eval() - model.device = device - - # we need cut ids to display recognition results. - args.return_cuts = True - yes_no = YesNoAsrDataModule(args) - test_dl = yes_no.test_dataloaders() - results = decode_dataset( - dl=test_dl, - params=params, - model=model, - ) - - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/decoder.py b/egs/yesno/ASR/transducer/decoder.py deleted file mode 100644 index 7ae540d03..000000000 --- a/egs/yesno/ASR/transducer/decoder.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional, Tuple - -import torch -import torch.nn as nn - - -class Decoder(nn.Module): - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - num_layers: int, - hidden_dim: int, - embedding_dropout: float = 0.0, - rnn_dropout: float = 0.0, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - num_layers: - Number of RNN layers. - hidden_dim: - Hidden dimension of RNN layers. - embedding_dropout: - Dropout rate for the embedding layer. - rnn_dropout: - Dropout for RNN layers. - """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.embedding_dropout = nn.Dropout(embedding_dropout) - self.rnn = nn.LSTM( - input_size=embedding_dim, - hidden_size=hidden_dim, - num_layers=num_layers, - batch_first=True, - dropout=rnn_dropout, - ) - self.blank_id = blank_id - self.output_linear = nn.Linear(hidden_dim, hidden_dim) - - def forward( - self, - y: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - y: - A 2-D tensor of shape (N, U). - states: - A tuple of two tensors containing the states information of - RNN layers in this decoder. - Returns: - Return a tuple containing: - - - rnn_output, a tensor of shape (N, U, C) - - (h, c), which contain the state information for RNN layers. - Both are of shape (num_layers, N, C) - """ - embedding_out = self.embedding(y) - embedding_out = self.embedding_dropout(embedding_out) - rnn_out, (h, c) = self.rnn(embedding_out, states) - out = self.output_linear(rnn_out) - - return out, (h, c) diff --git a/egs/yesno/ASR/transducer/encoder.py b/egs/yesno/ASR/transducer/encoder.py deleted file mode 100644 index 8c50df293..000000000 --- a/egs/yesno/ASR/transducer/encoder.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn - - -# We use a TDNN model as encoder, as it works very well with CTC training -# for this tiny dataset. -class Tdnn(nn.Module): - def __init__(self, num_features: int, output_dim: int): - """ - Args: - num_features: - Model input dimension. 
- ouput_dim: - Model output dimension - """ - super().__init__() - - # Note: We don't use paddings inside conv layers - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=32, - kernel_size=3, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=2, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=4, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - ) - self.output_linear = nn.Linear(in_features=32, out_features=output_dim) - - def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The input tensor with shape (N, T, C) - x_lens: - It contains the number of frames in each utterance in x - before padding. - - Returns: - Return a tuple with 2 tensors: - - - logits, a tensor of shape (N, T, C) - - logit_lens, a tensor of shape (N,) - """ - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.tdnn(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - logits = self.output_linear(x) - - # the first conv layer reduces T by 3-1 frames - # the second layer reduces T by (5-1)*2 frames - # the second layer reduces T by (5-1)*4 frames - # Number of output frames is 2 + 4*2 + 4*4 = 2 + 8 + 16 = 26 - x_lens = x_lens - 26 - return logits, x_lens diff --git a/egs/yesno/ASR/transducer/joiner.py b/egs/yesno/ASR/transducer/joiner.py deleted file mode 100644 index 0422f8a6f..000000000 --- a/egs/yesno/ASR/transducer/joiner.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): - super().__init__() - - self.output_linear = nn.Linear(input_dim, output_dim) - - def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, C). - decoder_out: - Output from the decoder. Its shape is (N, U, C). - Returns: - Return a tensor of shape (N, T, U, C). 
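The Joiner's forward (continued below) combines encoder and decoder outputs purely by broadcasting before the final linear layer. A quick standalone shape check with made-up sizes:

import torch

N, T, U, C = 2, 5, 3, 16
encoder_out = torch.randn(N, T, C).unsqueeze(2)  # (N, T, 1, C)
decoder_out = torch.randn(N, U, C).unsqueeze(1)  # (N, 1, U, C)
joint = encoder_out + decoder_out                # broadcasts to (N, T, U, C)
print(joint.shape)                               # torch.Size([2, 5, 3, 16])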
- """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) - - logit = encoder_out + decoder_out - logit = F.relu(logit) - - output = self.output_linear(logit) - - return output diff --git a/egs/yesno/ASR/transducer/model.py b/egs/yesno/ASR/transducer/model.py deleted file mode 100644 index caf9bed37..000000000 --- a/egs/yesno/ASR/transducer/model.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" -import k2 -import torch -import torch.nn as nn -import torchaudio -import torchaudio.functional - -from icefall.utils import add_sos - -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - joiner: nn.Module, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain - one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - Returns: - Return the transducer loss. 
- """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - decoder_out, _ = self.decoder(sos_y_padded) - - logits = self.joiner(encoder_out, decoder_out) - - # rnnt_loss requires 0 padded targets - y_padded = y.pad(mode="constant", padding_value=0) - - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="mean", - ) - - return loss diff --git a/egs/yesno/ASR/transducer/test_decoder.py b/egs/yesno/ASR/transducer/test_decoder.py deleted file mode 100755 index 88c54f678..000000000 --- a/egs/yesno/ASR/transducer/test_decoder.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/yesno/ASR - python ./transducer/test_decoder.py -""" - -import torch -from transducer.decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - embedding_dim = 128 - num_layers = 2 - hidden_dim = 6 - N = 3 - U = 5 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - num_layers=num_layers, - hidden_dim=hidden_dim, - embedding_dropout=0.0, - rnn_dropout=0.0, - ) - x = torch.randint(1, vocab_size, (N, U)) - rnn_out, (h, c) = decoder(x) - - assert rnn_out.shape == (N, U, hidden_dim) - assert h.shape == (num_layers, N, hidden_dim) - assert c.shape == (num_layers, N, hidden_dim) - - rnn_out, (h, c) = decoder(x, (h, c)) - assert rnn_out.shape == (N, U, hidden_dim) - assert h.shape == (num_layers, N, hidden_dim) - assert c.shape == (num_layers, N, hidden_dim) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/test_encoder.py b/egs/yesno/ASR/transducer/test_encoder.py deleted file mode 100755 index 481fb558b..000000000 --- a/egs/yesno/ASR/transducer/test_encoder.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/yesno/ASR - python ./transducer/test_encoder.py -""" - -import torch -from transducer.encoder import Tdnn - - -def test_encoder(): - input_dim = 10 - output_dim = 20 - encoder = Tdnn(input_dim, output_dim) - N = 10 - T = 85 - x = torch.rand(N, T, input_dim) - x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32) - logits, logit_lens = encoder(x, x_lens) - assert logits.shape == (N, T - 26, output_dim) - assert torch.all(torch.eq(x_lens - 26, logit_lens)) - - -def main(): - test_encoder() - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/test_joiner.py b/egs/yesno/ASR/transducer/test_joiner.py deleted file mode 100755 index 2773ca319..000000000 --- a/egs/yesno/ASR/transducer/test_joiner.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/yesno/ASR - python ./transducer/test_joiner.py -""" - - -import torch -from transducer.joiner import Joiner - - -def test_joiner(): - N = 2 - T = 3 - C = 4 - U = 5 - - joiner = Joiner(C, 10) - - encoder_out = torch.rand(N, T, C) - decoder_out = torch.rand(N, U, C) - - joint = joiner(encoder_out, decoder_out) - assert joint.shape == (N, T, U, 10) - - -def main(): - test_joiner() - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/test_transducer.py b/egs/yesno/ASR/transducer/test_transducer.py deleted file mode 100755 index db7bf9c68..000000000 --- a/egs/yesno/ASR/transducer/test_transducer.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
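test_encoder above asserts that the TDNN encoder shortens every utterance by 26 frames. A quick standalone check of that arithmetic, using the same kernel sizes and dilations as transducer/encoder.py (kernel 3 removes 2 frames, kernel 5 with dilation 2 removes 8, kernel 5 with dilation 4 removes 16):

import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv1d(23, 32, kernel_size=3),
    nn.Conv1d(32, 32, kernel_size=5, dilation=2),
    nn.Conv1d(32, 32, kernel_size=5, dilation=4),
)
x = torch.randn(1, 23, 85)   # (N, C, T) with T = 85
print(convs(x).shape)        # torch.Size([1, 32, 59]); 85 - 26 = 59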
- -""" -To run this file, do: - - cd icefall/egs/yesno/ASR - python ./transducer/test_transducer.py -""" - - -import k2 -import torch -from transducer.decoder import Decoder -from transducer.encoder import Tdnn -from transducer.joiner import Joiner -from transducer.model import Transducer - - -def test_transducer(): - # encoder params - input_dim = 10 - output_dim = 20 - - # decoder params - vocab_size = 3 - blank_id = 0 - embedding_dim = 128 - num_layers = 2 - - encoder = Tdnn(input_dim, output_dim) - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - num_layers=num_layers, - hidden_dim=output_dim, - embedding_dropout=0.0, - rnn_dropout=0.0, - ) - - joiner = Joiner(output_dim, vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) - - y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]]) - N = y.dim0 - T = 50 - - x = torch.rand(N, T, input_dim) - x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32) - x_lens[0] = T - - loss = transducer(x, x_lens, y) - print(loss) - - -def main(): - test_transducer() - - -if __name__ == "__main__": - main() diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py deleted file mode 100755 index 88866ae81..000000000 --- a/egs/yesno/ASR/transducer/train.py +++ /dev/null @@ -1,587 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import List, Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import YesNoAsrDataModule -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transducer.decoder import Decoder -from transducer.encoder import Tdnn -from transducer.joiner import Joiner -from transducer.model import Transducer - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_labels(texts: List[str]) -> k2.RaggedTensor: - """ - Args: - texts: - A list of transcripts. Each transcript contains spaces separated - "NO" or "YES". - Returns: - Return a ragged tensor containing the corresponding word ID. 
- """ - # blank is 0 - word2id = {"YES": 1, "NO": 2} - word_ids = [] - for t in texts: - words = t.split() - ids = [word2id[w] for w in words] - word_ids.append(ids) - - return k2.RaggedTensor(word_ids) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=200, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer/exp", - help="Directory to save results", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - """ - params = AttributeDict( - { - "lr": 1e-3, - "feature_dim": 23, - "weight_decay": 1e-6, - "start_epoch": 0, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 20, - "valid_interval": 10, - # encoder/decoder params - "vocab_size": 3, # blank, yes, no - "blank_id": 0, - "embedding_dim": 32, - "hidden_dim": 16, - "num_decoder_layers": 4, - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Tdnn in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
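As the compute_loss docstring above says, is_training only toggles autograd around the same forward computation. A minimal illustration of that pattern:

import torch

x = torch.randn(2, 3, requires_grad=True)

with torch.set_grad_enabled(False):   # validation-style call
    y = (x * 2).sum()
print(y.requires_grad)  # False

with torch.set_grad_enabled(True):    # training-style call
    y = (x * 2).sum()
print(y.requires_grad)  # True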
- """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - feature_lens = batch["supervisions"]["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - labels = get_labels(texts).to(device) - - with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=labels) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = feature.size(0) - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - is_training=True, - ) - # summary stats. 
- tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def get_transducer_model(params: AttributeDict): - encoder = Tdnn( - num_features=params.feature_dim, - output_dim=params.hidden_dim, - ) - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.hidden_dim, - embedding_dropout=0.4, - rnn_dropout=0.4, - ) - joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) - - return transducer - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"device: {device}") - - model = get_transducer_model(params) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - model.device = device - - optimizer = optim.Adam( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - yes_no = YesNoAsrDataModule(args) - train_dl = yes_no.train_dataloaders() - - # There are only 60 waves: 30 files are used for training - # and the remaining 30 files are used for testing. - # We use test data as validation. 
-    valid_dl = yes_no.test_dataloaders()
-
-    for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed + epoch)
-        train_dl.sampler.set_epoch(epoch)
-
-        if tb_writer is not None:
-            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
-
-        params.cur_epoch = epoch
-
-        train_one_epoch(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            train_dl=train_dl,
-            valid_dl=valid_dl,
-            tb_writer=tb_writer,
-            world_size=world_size,
-        )
-
-        save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            scheduler=None,
-            rank=rank,
-        )
-
-    logging.info("Done!")
-    if world_size > 1:
-        torch.distributed.barrier()
-        cleanup_dist()
-
-
-def main():
-    parser = get_parser()
-    YesNoAsrDataModule.add_arguments(parser)
-    args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
-
-    world_size = args.world_size
-    assert world_size >= 1
-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
-    else:
-        run(rank=0, world_size=1, args=args)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/icefall/__init__.py b/icefall/__init__.py old mode 100644 new mode 100755
diff --git a/icefall/ali.py b/icefall/ali.py old mode 100644 new mode 100755
diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py
old mode 100644
new mode 100755
index d9659c2dd..dfb55216f
--- a/icefall/bpe_graph_compiler.py
+++ b/icefall/bpe_graph_compiler.py
@@ -50,7 +50,7 @@ class BpeCtcTrainingGraphCompiler(object):
         sp = spm.SentencePieceProcessor()
         sp.load(str(model_file))
         self.sp = sp
-        self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+        # self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
 
         self.device = device
         self.sos_id = self.sp.piece_to_id(sos_token)
diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py old mode 100644 new mode 100755
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py old mode 100644 new mode 100755
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py old mode 100644 new mode 100755
diff --git a/icefall/context_graph.py b/icefall/context_graph.py old mode 100644 new mode 100755
diff --git a/icefall/ctc/.gitignore b/icefall/ctc/.gitignore old mode 100644 new mode 100755
diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md old mode 100644 new mode 100755
diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py old mode 100644 new mode 100755
diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py old mode 100644 new mode 100755
diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py old mode 100644 new mode 100755
diff --git a/icefall/ctc/utils.py b/icefall/ctc/utils.py old mode 100644 new mode 100755
diff --git a/icefall/dataset/__init__.py b/icefall/dataset/__init__.py old mode 100644 new mode 100755
diff --git a/icefall/dataset/datamodule.py b/icefall/dataset/datamodule.py old mode 100644 new mode 100755
diff --git a/icefall/decode.py b/icefall/decode.py old mode 100644 new mode 100755
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py old mode 100644 new mode 100755
diff --git a/icefall/dist.py b/icefall/dist.py old mode 100644 new mode 100755
diff --git a/icefall/env.py b/icefall/env.py old mode 100644 new mode 100755
diff --git a/icefall/err.py b/icefall/err.py old mode 100644 new mode 100755
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py old mode 100644 new mode 100755
diff --git a/icefall/hooks.py b/icefall/hooks.py old mode 100644 new mode 100755
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
old mode 100644
new mode 100755
index 22e1b78bb..88351f19a
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -162,20 +162,20 @@ class Lexicon(object):
         """
         lang_dir = Path(lang_dir)
         self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
-        self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+        # self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
 
-        if (lang_dir / "Linv.pt").exists():
-            logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt")
-            L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
-        else:
-            logging.info("Converting L.pt to Linv.pt")
-            L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
-            L_inv = k2.arc_sort(L.invert())
-            torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
+        # if (lang_dir / "Linv.pt").exists():
+        #     logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt")
+        #     L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
+        # else:
+        #     logging.info("Converting L.pt to Linv.pt")
+        #     L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
+        #     L_inv = k2.arc_sort(L.invert())
+        #     torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
 
-        # We save L_inv instead of L because it will be used to intersect with
-        # transcript FSAs, both of whose labels are word IDs.
-        self.L_inv = L_inv
+        # # We save L_inv instead of L because it will be used to intersect with
+        # # transcript FSAs, both of whose labels are word IDs.
+        # self.L_inv = L_inv
 
         self.disambig_pattern = disambig_pattern
 
     @property
diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py old mode 100644 new mode 100755
diff --git a/icefall/mmi.py b/icefall/mmi.py old mode 100644 new mode 100755
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py old mode 100644 new mode 100755
diff --git a/icefall/ngram_lm.py b/icefall/ngram_lm.py old mode 100644 new mode 100755
diff --git a/icefall/otc_graph_compiler.py b/icefall/otc_graph_compiler.py old mode 100644 new mode 100755
diff --git a/icefall/otc_phone_graph_compiler.py b/icefall/otc_phone_graph_compiler.py old mode 100644 new mode 100755
diff --git a/icefall/profiler.py b/icefall/profiler.py old mode 100644 new mode 100755
diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore old mode 100644 new mode 100755
diff --git a/icefall/rnn_lm/__init__.py b/icefall/rnn_lm/__init__.py old mode 100644 new mode 100755
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py old mode 100644 new mode 100755
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py old mode 100644 new mode 100755
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/__init__.py b/icefall/transformer_lm/__init__.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/attention.py b/icefall/transformer_lm/attention.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/compute_perplexity.py b/icefall/transformer_lm/compute_perplexity.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/encoder.py b/icefall/transformer_lm/encoder.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/export.py b/icefall/transformer_lm/export.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py old mode 100644 new mode 100755
diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py old mode 100644 new mode 100755
diff --git a/icefall/utils.py b/icefall/utils.py old mode 100644 new mode 100755
diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755
diff --git a/requirements-ci.txt b/requirements-ci.txt old mode 100644 new mode 100755
diff --git a/requirements-tts.txt b/requirements-tts.txt old mode 100644 new mode 100755
diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755
diff --git a/setup.py b/setup.py old mode 100644 new mode 100755
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py old mode 100644 new mode 100755
diff --git a/test/test_decode.py b/test/test_decode.py old mode 100644 new mode 100755
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py old mode 100644 new mode 100755
diff --git a/test/test_utils.py b/test/test_utils.py old mode 100644 new mode 100755